Internalize destination device list/registration ID checks in MessageSender

This commit is contained in:
Jon Chambers
2025-04-07 09:15:39 -04:00
committed by GitHub
parent 1d0e2d29a7
commit c6689ca07a
21 changed files with 675 additions and 755 deletions

View File

@@ -11,13 +11,24 @@ import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.util.DataSize;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.util.Pair;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
@@ -58,6 +69,9 @@ public class MessageSender {
public static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes();
private static final long LARGE_MESSAGE_SIZE = DataSize.kibibytes(8).toBytes();
@VisibleForTesting
static final byte NO_EXCLUDED_DEVICE_ID = -1;
public MessageSender(final MessagesManager messagesManager, final PushNotificationManager pushNotificationManager) {
this.messagesManager = messagesManager;
this.pushNotificationManager = pushNotificationManager;
@@ -68,23 +82,51 @@ public class MessageSender {
* notification token and does not have an active connection to a Signal server, then this method will also send a
* push notification to that device to announce the availability of new messages.
*
* @param account the account to which to send messages
* @param destination the account to which to send messages
* @param destinationIdentifier the service identifier to which the messages are addressed
* @param messagesByDeviceId a map of device IDs to message payloads
* @param registrationIdsByDeviceId a map of device IDs to device registration IDs
*/
public void sendMessages(final Account account, final Map<Byte, Envelope> messagesByDeviceId) {
messagesManager.insert(account.getIdentifier(IdentityType.ACI), messagesByDeviceId)
public void sendMessages(final Account destination,
final ServiceIdentifier destinationIdentifier,
final Map<Byte, Envelope> messagesByDeviceId,
final Map<Byte, Integer> registrationIdsByDeviceId) throws MismatchedDevicesException {
if (messagesByDeviceId.isEmpty()) {
return;
}
if (!destination.isIdentifiedBy(destinationIdentifier)) {
throw new IllegalArgumentException("Destination account not identified by destination service identifier");
}
final Envelope firstMessage = messagesByDeviceId.values().iterator().next();
final boolean isSyncMessage = StringUtils.isNotBlank(firstMessage.getSourceServiceId()) &&
destination.isIdentifiedBy(ServiceIdentifier.valueOf(firstMessage.getSourceServiceId()));
final Optional<MismatchedDevices> maybeMismatchedDevices = getMismatchedDevices(destination,
destinationIdentifier,
registrationIdsByDeviceId,
isSyncMessage ? (byte) firstMessage.getSourceDevice() : NO_EXCLUDED_DEVICE_ID);
if (maybeMismatchedDevices.isPresent()) {
throw new MismatchedDevicesException(maybeMismatchedDevices.get());
}
messagesManager.insert(destination.getIdentifier(IdentityType.ACI), messagesByDeviceId)
.forEach((deviceId, destinationPresent) -> {
final Envelope message = messagesByDeviceId.get(deviceId);
if (!destinationPresent && !message.getEphemeral()) {
try {
pushNotificationManager.sendNewMessageNotification(account, deviceId, message.getUrgent());
pushNotificationManager.sendNewMessageNotification(destination, deviceId, message.getUrgent());
} catch (final NotPushRegisteredException ignored) {
}
}
Metrics.counter(SEND_COUNTER_NAME,
CHANNEL_TAG_NAME, account.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"),
CHANNEL_TAG_NAME, destination.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"),
EPHEMERAL_TAG_NAME, String.valueOf(message.getEphemeral()),
CLIENT_ONLINE_TAG_NAME, String.valueOf(destinationPresent),
URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
@@ -98,6 +140,10 @@ public class MessageSender {
* Sends messages to a group of recipients. If a destination device has a valid push notification token and does not
* have an active connection to a Signal server, then this method will also send a push notification to that device to
* announce the availability of new messages.
* <p>
* This method sends messages to all <em>resolved</em> recipients. In some cases, a caller may not be able to resolve
* all recipients to active accounts, but may still choose to send the message. Callers are responsible for rejecting
* the message if they require full resolution of all recipients, but some recipients could not be resolved.
*
* @param multiRecipientMessage the multi-recipient message to send to the given recipients
* @param resolvedRecipients a map of recipients to resolved Signal accounts
@@ -114,7 +160,31 @@ public class MessageSender {
final long clientTimestamp,
final boolean isStory,
final boolean isEphemeral,
final boolean isUrgent) {
final boolean isUrgent) throws MultiRecipientMismatchedDevicesException {
final Map<ServiceIdentifier, MismatchedDevices> mismatchedDevicesByServiceIdentifier = new HashMap<>();
multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> {
if (!resolvedRecipients.containsKey(recipient)) {
// Callers are responsible for rejecting messages if they're missing recipients in a problematic way. If we run
// into an unresolved recipient here, just skip it.
return;
}
final Account account = resolvedRecipients.get(recipient);
final ServiceIdentifier serviceIdentifier = ServiceIdentifier.fromLibsignal(serviceId);
final Map<Byte, Integer> registrationIdsByDeviceId = recipient.getDevicesAndRegistrationIds()
.collect(Collectors.toMap(Pair::first, pair -> (int) pair.second()));
getMismatchedDevices(account, serviceIdentifier, registrationIdsByDeviceId, NO_EXCLUDED_DEVICE_ID)
.ifPresent(mismatchedDevices ->
mismatchedDevicesByServiceIdentifier.put(serviceIdentifier, mismatchedDevices));
});
if (!mismatchedDevicesByServiceIdentifier.isEmpty()) {
throw new MultiRecipientMismatchedDevicesException(mismatchedDevicesByServiceIdentifier);
}
return messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp,
isStory, isEphemeral, isUrgent)
@@ -189,4 +259,47 @@ public class MessageSender {
.increment();
}
}
@VisibleForTesting
static Optional<MismatchedDevices> getMismatchedDevices(final Account account,
final ServiceIdentifier serviceIdentifier,
final Map<Byte, Integer> registrationIdsByDeviceId,
final byte excludedDeviceId) {
final Set<Byte> accountDeviceIds = account.getDevices().stream()
.map(Device::getId)
.filter(deviceId -> deviceId != excludedDeviceId)
.collect(Collectors.toSet());
final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds);
missingDeviceIds.removeAll(registrationIdsByDeviceId.keySet());
final Set<Byte> extraDeviceIds = new HashSet<>(registrationIdsByDeviceId.keySet());
extraDeviceIds.removeAll(accountDeviceIds);
final Set<Byte> staleDeviceIds = registrationIdsByDeviceId.entrySet().stream()
// Filter out device IDs that aren't associated with the given account
.filter(entry -> !extraDeviceIds.contains(entry.getKey()))
.filter(entry -> {
final byte deviceId = entry.getKey();
final int registrationId = entry.getValue();
// We know the device must be present because we've already filtered out device IDs that aren't associated
// with the given account
final Device device = account.getDevice(deviceId).orElseThrow();
final int expectedRegistrationId = switch (serviceIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId);
};
return registrationId != expectedRegistrationId;
})
.map(Map.Entry::getKey)
.collect(Collectors.toSet());
return (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty() || !staleDeviceIds.isEmpty())
? Optional.of(new MismatchedDevices(missingDeviceIds, extraDeviceIds, staleDeviceIds))
: Optional.empty();
}
}

View File

@@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.push;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import org.slf4j.Logger;
@@ -54,9 +55,20 @@ public class ReceiptSender {
.setUrgent(false)
.build();
final Map<Byte, Envelope> messagesByDeviceId = destinationAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, ignored -> message));
final Map<Byte, Integer> registrationIdsByDeviceId = destinationAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, device -> switch (destinationIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId);
}));
try {
messageSender.sendMessages(destinationAccount, destinationAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, ignored -> message)));
messageSender.sendMessages(destinationAccount,
destinationIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId);
} catch (final Exception e) {
logger.warn("Could not send delivery receipt", e);
}