Validate intra-account messages before applying number changes

This commit is contained in:
Jon Chambers
2025-07-17 11:34:50 -04:00
committed by GitHub
parent 50bc6b2c62
commit 4ccd39fd55
8 changed files with 404 additions and 191 deletions

View File

@@ -700,7 +700,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
final AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager);
final MessageSender messageSender = new MessageSender(messagesManager, pushNotificationManager, experimentEnrollmentManager);
final MessageSender messageSender = new MessageSender(messagesManager, pushNotificationManager);
final ReceiptSender receiptSender = new ReceiptSender(accountsManager, messageSender, receiptSenderExecutor);
final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = new CloudflareTurnCredentialsManager(
config.getTurnConfiguration().cloudflare().apiToken().value(),

View File

@@ -429,7 +429,8 @@ public class MessageController {
isStory,
messages.online(),
messages.urgent(),
spamCheckResult.token().orElse(null));
spamCheckResult.token().orElse(null),
clock);
} catch (final IllegalArgumentException e) {
logger.warn("Received bad envelope type {} from {}", message.type(), userAgent);
throw new BadRequestException(e);

View File

@@ -12,6 +12,7 @@ import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import java.time.Clock;
import java.util.Arrays;
import java.util.Objects;
import javax.annotation.Nullable;
@@ -40,14 +41,15 @@ public record IncomingMessage(int type,
final boolean story,
final boolean ephemeral,
final boolean urgent,
@Nullable byte[] reportSpamToken) {
@Nullable byte[] reportSpamToken,
final Clock clock) {
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder();
envelopeBuilder
.setType(MessageProtos.Envelope.Type.forNumber(type))
.setClientTimestamp(timestamp)
.setServerTimestamp(System.currentTimeMillis())
.setServerTimestamp(clock.millis())
.setDestinationServiceId(destinationIdentifier.toServiceIdentifierString())
.setStory(story)
.setEphemeral(ephemeral)

View File

@@ -28,7 +28,6 @@ import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
@@ -53,7 +52,6 @@ public class MessageSender {
private final MessagesManager messagesManager;
private final PushNotificationManager pushNotificationManager;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
// Note that these names deliberately reference `MessageController` for metric continuity
private static final String REJECT_OVERSIZE_MESSAGE_COUNTER_NAME = name(MessageController.class, "rejectOversizeMessage");
@@ -75,13 +73,9 @@ public class MessageSender {
@VisibleForTesting
static final byte NO_EXCLUDED_DEVICE_ID = -1;
public MessageSender(
final MessagesManager messagesManager,
final PushNotificationManager pushNotificationManager,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
public MessageSender(final MessagesManager messagesManager, final PushNotificationManager pushNotificationManager) {
this.messagesManager = messagesManager;
this.pushNotificationManager = pushNotificationManager;
this.experimentEnrollmentManager = experimentEnrollmentManager;
}
/**
@@ -109,44 +103,14 @@ public class MessageSender {
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Byte> syncMessageSenderDeviceId,
@Nullable final String userAgent) throws MismatchedDevicesException, MessageTooLargeException {
if (!destination.isIdentifiedBy(destinationIdentifier)) {
throw new IllegalArgumentException("Destination account not identified by destination service identifier");
}
final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent);
if (messagesByDeviceId.isEmpty()) {
Metrics.counter(EMPTY_MESSAGE_LIST_COUNTER_NAME,
Tags.of(SYNC_MESSAGE_TAG_NAME, String.valueOf(syncMessageSenderDeviceId.isPresent())).and(platformTag)).increment();
}
final byte excludedDeviceId;
if (syncMessageSenderDeviceId.isPresent()) {
if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isBlank(message.getSourceServiceId()) ||
!destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) {
throw new IllegalArgumentException("Sync message sender device ID specified, but one or more messages are not addressed to sender");
}
excludedDeviceId = syncMessageSenderDeviceId.get();
} else {
if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isNotBlank(message.getSourceServiceId()) &&
destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) {
throw new IllegalArgumentException("Sync message sender device ID not specified, but one or more messages are addressed to sender");
}
excludedDeviceId = NO_EXCLUDED_DEVICE_ID;
}
final Optional<MismatchedDevices> maybeMismatchedDevices = getMismatchedDevices(destination,
validateIndividualMessageBundle(destination,
destinationIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
excludedDeviceId);
if (maybeMismatchedDevices.isPresent()) {
throw new MismatchedDevicesException(maybeMismatchedDevices.get());
}
validateIndividualMessageContentLength(messagesByDeviceId.values(), syncMessageSenderDeviceId.isPresent(), userAgent);
syncMessageSenderDeviceId,
platformTag);
messagesManager.insert(destination.getIdentifier(IdentityType.ACI), messagesByDeviceId)
.forEach((deviceId, destinationPresent) -> {
@@ -206,7 +170,9 @@ public class MessageSender {
final boolean isUrgent,
@Nullable final String userAgent) throws MultiRecipientMismatchedDevicesException, MessageTooLargeException {
validateMultiRecipientMessageContentLength(multiRecipientMessage, isStory, userAgent);
final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent);
validateMultiRecipientMessageContentLength(multiRecipientMessage, isStory, platformTag);
final Map<ServiceIdentifier, MismatchedDevices> mismatchedDevicesByServiceIdentifier = new HashMap<>();
@@ -252,24 +218,105 @@ public class MessageSender {
SEALED_SENDER_TAG_NAME, "true",
SYNC_MESSAGE_TAG_NAME, "false",
MULTI_RECIPIENT_TAG_NAME, "true")
.and(UserAgentTagUtil.getPlatformTag(userAgent));
.and(platformTag);
Metrics.counter(SEND_COUNTER_NAME, tags).increment();
})))
.thenRun(Util.NOOP);
}
/**
* Validates that a bundle of messages destined for an individual account is well-formed and may be delivered. Note
* that all checks performed by this method are also performed by
* {@link #sendMessages(Account, ServiceIdentifier, Map, Map, Optional, String)}; callers should only invoke this
* method if they need to verify that a bundle of individual messages is valid <em>before</em> trying to send the
* messages (i.e. if the caller must take some other action in conjunction with sending the messages and cannot
* reverse that action if message sending fails).
*
* @param destination the account to which to send messages
* @param destinationIdentifier the service identifier to which the messages are addressed
* @param messagesByDeviceId a map of device IDs to message payloads
* @param registrationIdsByDeviceId a map of device IDs to device registration IDs
* @param syncMessageSenderDeviceId if the message is a sync message (i.e. a message to other devices linked to the
* caller's own account), contains the ID of the device that sent the message
* @param userAgent the User-Agent string for the sender; may be {@code null} if not known
*
* @throws MismatchedDevicesException if the given bundle of messages did not include a message for all required
* devices, contained messages for devices not linked to the destination account, or devices with outdated
* registration IDs
* @throws MessageTooLargeException if the given message payload is too large
*/
public static void validateIndividualMessageBundle(final Account destination,
final ServiceIdentifier destinationIdentifier,
final Map<Byte, Envelope> messagesByDeviceId,
final Map<Byte, Integer> registrationIdsByDeviceId,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Byte> syncMessageSenderDeviceId,
@Nullable final String userAgent) throws MessageTooLargeException, MismatchedDevicesException {
validateIndividualMessageBundle(destination,
destinationIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
syncMessageSenderDeviceId,
UserAgentTagUtil.getPlatformTag(userAgent));
}
private static void validateIndividualMessageBundle(final Account destination,
final ServiceIdentifier destinationIdentifier,
final Map<Byte, Envelope> messagesByDeviceId,
final Map<Byte, Integer> registrationIdsByDeviceId,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Byte> syncMessageSenderDeviceId,
final Tag platformTag) throws MismatchedDevicesException, MessageTooLargeException {
if (!destination.isIdentifiedBy(destinationIdentifier)) {
throw new IllegalArgumentException("Destination account not identified by destination service identifier");
}
if (messagesByDeviceId.isEmpty()) {
Metrics.counter(EMPTY_MESSAGE_LIST_COUNTER_NAME,
Tags.of(SYNC_MESSAGE_TAG_NAME, String.valueOf(syncMessageSenderDeviceId.isPresent())).and(platformTag)).increment();
}
final byte excludedDeviceId;
if (syncMessageSenderDeviceId.isPresent()) {
if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isBlank(message.getSourceServiceId()) ||
!destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) {
throw new IllegalArgumentException("Sync message sender device ID specified, but one or more messages are not addressed to sender");
}
excludedDeviceId = syncMessageSenderDeviceId.get();
} else {
if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isNotBlank(message.getSourceServiceId()) &&
destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) {
throw new IllegalArgumentException("Sync message sender device ID not specified, but one or more messages are addressed to sender");
}
excludedDeviceId = NO_EXCLUDED_DEVICE_ID;
}
final Optional<MismatchedDevices> maybeMismatchedDevices = getMismatchedDevices(destination,
destinationIdentifier,
registrationIdsByDeviceId,
excludedDeviceId);
if (maybeMismatchedDevices.isPresent()) {
throw new MismatchedDevicesException(maybeMismatchedDevices.get());
}
validateIndividualMessageContentLength(messagesByDeviceId.values(), syncMessageSenderDeviceId.isPresent(), platformTag);
}
@VisibleForTesting
static void validateContentLength(final int contentLength,
final boolean isMultiRecipientMessage,
final boolean isSyncMessage,
final boolean isStory,
final String userAgent) throws MessageTooLargeException {
final Tag platformTag) throws MessageTooLargeException {
final boolean oversize = contentLength > MAX_MESSAGE_SIZE;
DistributionSummary.builder(CONTENT_SIZE_DISTRIBUTION_NAME)
.tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
.tags(Tags.of(platformTag,
Tag.of("oversize", String.valueOf(oversize)),
Tag.of("multiRecipientMessage", String.valueOf(isMultiRecipientMessage)),
Tag.of("syncMessage", String.valueOf(isSyncMessage)),
@@ -279,7 +326,7 @@ public class MessageSender {
.record(contentLength);
if (oversize) {
Metrics.counter(REJECT_OVERSIZE_MESSAGE_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
Metrics.counter(REJECT_OVERSIZE_MESSAGE_COUNTER_NAME, Tags.of(platformTag,
Tag.of("multiRecipientMessage", String.valueOf(isMultiRecipientMessage)),
Tag.of("syncMessage", String.valueOf(isSyncMessage)),
Tag.of("story", String.valueOf(isStory))))
@@ -330,27 +377,27 @@ public class MessageSender {
private static void validateIndividualMessageContentLength(final Iterable<Envelope> messages,
final boolean isSyncMessage,
@Nullable final String userAgent) throws MessageTooLargeException {
final Tag platformTag) throws MessageTooLargeException {
for (final Envelope message : messages) {
MessageSender.validateContentLength(message.getContent().size(),
false,
isSyncMessage,
message.getStory(),
userAgent);
platformTag);
}
}
private static void validateMultiRecipientMessageContentLength(final SealedSenderMultiRecipientMessage multiRecipientMessage,
final boolean isStory,
@Nullable final String userAgent) throws MessageTooLargeException {
final Tag platformTag) throws MessageTooLargeException {
for (final SealedSenderMultiRecipientMessage.Recipient recipient : multiRecipientMessage.getRecipients().values()) {
MessageSender.validateContentLength(multiRecipientMessage.messageSizeForRecipient(recipient),
true,
false,
isStory,
userAgent);
platformTag);
}
}
}

View File

@@ -4,13 +4,11 @@
*/
package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString;
import java.time.Clock;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ObjectUtils;
import org.signal.libsignal.protocol.IdentityKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -20,7 +18,7 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
@@ -51,48 +49,57 @@ public class ChangeNumberManager {
final String senderUserAgent)
throws InterruptedException, MismatchedDevicesException, MessageTooLargeException {
final long serverTimestamp = clock.millis();
final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getIdentifier(IdentityType.ACI));
// Note that these for-validation envelopes do NOT have the "updated PNI" field set, and we'll need to populate that
// after actually changing the account's number.
final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> message.toEnvelope(serviceIdentifier,
serviceIdentifier,
Device.PRIMARY_ID,
serverTimestamp,
false,
false,
true,
null,
clock)));
final Map<Byte, Integer> registrationIdsByDeviceId = deviceMessages.stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
// Make sure we can plausibly deliver the messages to other devices on the account before making any changes to the
// account itself
if (!messagesByDeviceId.isEmpty()) {
MessageSender.validateIndividualMessageBundle(account,
serviceIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
Optional.of(Device.PRIMARY_ID),
senderUserAgent);
}
final Account updatedAccount = accountsManager.changeNumber(
account, number, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, pniRegistrationIds);
sendDeviceMessages(updatedAccount, deviceMessages, senderUserAgent);
return updatedAccount;
}
private void sendDeviceMessages(final Account account,
final List<IncomingMessage> deviceMessages,
final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException {
try {
final long serverTimestamp = clock.millis();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid());
// Now that we've actually updated the account, populate the "updated PNI" field on all envelopes
final String updatedPniString = updatedAccount.getIdentifier(IdentityType.PNI).toString();
final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> Envelope.newBuilder()
.setType(Envelope.Type.forNumber(message.type()))
.setClientTimestamp(serverTimestamp)
.setServerTimestamp(serverTimestamp)
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.setContent(ByteString.copyFrom(message.content()))
.setSourceServiceId(serviceIdentifier.toServiceIdentifierString())
.setSourceDevice(Device.PRIMARY_ID)
.setUpdatedPni(account.getPhoneNumberIdentifier().toString())
.setUrgent(true)
.setEphemeral(false)
.build()));
messagesByDeviceId.replaceAll((deviceId, envelope) ->
envelope.toBuilder().setUpdatedPni(updatedPniString).build());
final Map<Byte, Integer> registrationIdsByDeviceId = deviceMessages.stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
messageSender.sendMessages(account,
messageSender.sendMessages(updatedAccount,
serviceIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
Optional.of(Device.PRIMARY_ID),
senderUserAgent);
} catch (final RuntimeException e) {
logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e);
logger.warn("Changed number but could not send all device messages for {}", account.getIdentifier(IdentityType.ACI), e);
throw e;
}
return updatedAccount;
}
}