Internalize destination device list/registration ID checks in MessageSender

This commit is contained in:
Jon Chambers
2025-04-07 09:15:39 -04:00
committed by GitHub
parent 1d0e2d29a7
commit c6689ca07a
21 changed files with 675 additions and 755 deletions

View File

@@ -43,12 +43,12 @@ import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager
import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse;
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse;
import org.whispersystems.textsecuregcm.entities.PhoneNumberDiscoverabilityRequest;
import org.whispersystems.textsecuregcm.entities.PhoneNumberIdentityKeyDistributionRequest;
import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest;
import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
@@ -93,8 +93,8 @@ public class AccountControllerV2 {
@ApiResponse(responseCode = "200", description = "The phone number associated with the authenticated account was changed successfully", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "403", description = "Verification failed for the provided Registration Recovery Password")
@ApiResponse(responseCode = "409", description = "Mismatched number of devices or device ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = MismatchedDevices.class)))
@ApiResponse(responseCode = "410", description = "Mismatched registration ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = StaleDevices.class)))
@ApiResponse(responseCode = "409", description = "Mismatched number of devices or device ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class)))
@ApiResponse(responseCode = "410", description = "Mismatched registration ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
@ApiResponse(responseCode = "413", description = "One or more device messages was too large")
@ApiResponse(responseCode = "422", description = "The request did not pass validation")
@ApiResponse(responseCode = "423", content = @Content(schema = @Schema(implementation = RegistrationLockFailure.class)))
@@ -150,16 +150,18 @@ public class AccountControllerV2 {
return AccountIdentityResponseBuilder.fromAccount(updatedAccount);
} catch (MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds()))
.build());
} else {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(),
e.getMismatchedDevices().extraDeviceIds()))
.build());
}
} catch (IllegalArgumentException e) {
throw new BadRequestException(e);
} catch (MessageTooLargeException e) {
@@ -178,9 +180,9 @@ public class AccountControllerV2 {
@ApiResponse(responseCode = "403", description = "This endpoint can only be invoked from the account's primary device.")
@ApiResponse(responseCode = "422", description = "The request body failed validation.")
@ApiResponse(responseCode = "409", description = "The set of devices specified in the request does not match the set of devices active on the account.",
content = @Content(schema = @Schema(implementation = MismatchedDevices.class)))
content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class)))
@ApiResponse(responseCode = "410", description = "The registration IDs provided for some devices do not match those stored on the server.",
content = @Content(schema = @Schema(implementation = StaleDevices.class)))
content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
@ApiResponse(responseCode = "413", description = "One or more device messages was too large")
public AccountIdentityResponse distributePhoneNumberIdentityKeys(
@Mutable @Auth final AuthenticatedDevice authenticatedDevice,
@@ -207,16 +209,18 @@ public class AccountControllerV2 {
return AccountIdentityResponseBuilder.fromAccount(updatedAccount);
} catch (MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds()))
.build());
} else {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(),
e.getMismatchedDevices().extraDeviceIds()))
.build());
}
} catch (IllegalArgumentException e) {
throw new BadRequestException(e);
} catch (MessageTooLargeException e) {

View File

@@ -44,14 +44,12 @@ import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status;
import java.time.Clock;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
@@ -62,7 +60,6 @@ import javax.annotation.Nullable;
import org.glassfish.jersey.server.ManagedAsync;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.util.Pair;
import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.signal.libsignal.zkgroup.VerificationFailedException;
import org.signal.libsignal.zkgroup.groupsend.GroupSendDerivedKeyPair;
@@ -81,13 +78,13 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope.Type;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMessageResponse;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
@@ -113,7 +110,6 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
@@ -245,10 +241,10 @@ public class MessageController {
description="The message is not a story and some the recipient service ID does not correspond to a registered Signal user")
@ApiResponse(
responseCode = "409", description = "Incorrect set of devices supplied for recipient",
content = @Content(schema = @Schema(implementation = MismatchedDevices.class)))
content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class)))
@ApiResponse(
responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices",
content = @Content(schema = @Schema(implementation = StaleDevices.class)))
content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
@ApiResponse(
responseCode="428",
description="The sender should complete a challenge before proceeding")
@@ -381,14 +377,6 @@ public class MessageController {
rateLimiters.getStoriesLimiter().validate(destination.getUuid());
}
final Set<Byte> excludedDeviceIds;
if (isSyncMessage) {
excludedDeviceIds = Set.of(source.get().getAuthenticatedDevice().getId());
} else {
excludedDeviceIds = Collections.emptySet();
}
final Map<Byte, Envelope> messagesByDeviceId = messages.messages().stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> {
try {
@@ -407,15 +395,8 @@ public class MessageController {
}
}));
DestinationDeviceValidator.validateCompleteDeviceList(destination,
messagesByDeviceId.keySet(),
excludedDeviceIds);
DestinationDeviceValidator.validateRegistrationIds(destination,
messages.messages(),
IncomingMessage::destinationDeviceId,
IncomingMessage::destinationRegistrationId,
destination.getPhoneNumberIdentifier().equals(destinationIdentifier.uuid()));
final Map<Byte, Integer> registrationIdsByDeviceId = messages.messages().stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
final String authType;
if (SENDER_TYPE_IDENTIFIED.equals(senderType)) {
@@ -428,7 +409,7 @@ public class MessageController {
authType = AUTH_TYPE_ACCESS_KEY;
}
messageSender.sendMessages(destination, messagesByDeviceId);
messageSender.sendMessages(destination, destinationIdentifier, messagesByDeviceId, registrationIdsByDeviceId);
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, List.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_SINGLE),
@@ -440,16 +421,18 @@ public class MessageController {
return Response.ok(new SendMessageResponse(needsSync)).build();
} catch (final MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (final StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds()))
.build());
} else {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(),
e.getMismatchedDevices().extraDeviceIds()))
.build());
}
}
} finally {
sample.stop(Timer.builder(SEND_MESSAGE_LATENCY_TIMER_NAME)
@@ -622,57 +605,6 @@ public class MessageController {
}
}
final Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
final Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> {
if (!resolvedRecipients.containsKey(recipient)) {
// When sending stories, we might not be able to resolve all recipients to existing accounts. That's okay! We
// can just skip them.
return;
}
final Account account = resolvedRecipients.get(recipient);
try {
final Map<Byte, Short> deviceIdsToRegistrationIds = recipient.getDevicesAndRegistrationIds()
.collect(Collectors.toMap(Pair<Byte, Short>::first, Pair<Byte, Short>::second));
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIdsToRegistrationIds.keySet(),
Collections.emptySet());
DestinationDeviceValidator.validateRegistrationIds(
account,
deviceIdsToRegistrationIds.entrySet(),
Map.Entry<Byte, Short>::getKey,
e -> Integer.valueOf(e.getValue()),
serviceId instanceof ServiceId.Pni);
} catch (final MismatchedDevicesException e) {
accountMismatchedDevices.add(
new AccountMismatchedDevices(
ServiceIdentifier.fromLibsignal(serviceId),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (final StaleDevicesException e) {
accountStaleDevices.add(
new AccountStaleDevices(ServiceIdentifier.fromLibsignal(serviceId), new StaleDevices(e.getStaleDevices())));
}
});
if (!accountMismatchedDevices.isEmpty()) {
return Response
.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(accountMismatchedDevices)
.build();
}
if (!accountStaleDevices.isEmpty()) {
return Response
.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(accountStaleDevices)
.build();
}
final String authType;
if (isStory) {
authType = AUTH_TYPE_STORY;
@@ -731,6 +663,38 @@ public class MessageController {
} catch (ExecutionException e) {
logger.error("partial failure while delivering multi-recipient messages", e.getCause());
throw new InternalServerErrorException("failure during delivery");
} catch (MultiRecipientMismatchedDevicesException e) {
final List<AccountMismatchedDevices> accountMismatchedDevices =
e.getMismatchedDevicesByServiceIdentifier().entrySet().stream()
.filter(entry -> !entry.getValue().missingDeviceIds().isEmpty() || !entry.getValue().extraDeviceIds().isEmpty())
.map(entry -> new AccountMismatchedDevices(entry.getKey(),
new MismatchedDevicesResponse(entry.getValue().missingDeviceIds(), entry.getValue().extraDeviceIds())))
.toList();
if (!accountMismatchedDevices.isEmpty()) {
return Response
.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(accountMismatchedDevices)
.build();
}
final List<AccountStaleDevices> accountStaleDevices =
e.getMismatchedDevicesByServiceIdentifier().entrySet().stream()
.filter(entry -> !entry.getValue().staleDeviceIds().isEmpty())
.map(entry -> new AccountStaleDevices(entry.getKey(),
new StaleDevicesResponse(entry.getValue().staleDeviceIds())))
.toList();
if (!accountStaleDevices.isEmpty()) {
return Response
.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(accountStaleDevices)
.build();
}
throw new RuntimeException(e);
}
}

View File

@@ -0,0 +1,11 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import java.util.Set;
public record MismatchedDevices(Set<Byte> missingDeviceIds, Set<Byte> extraDeviceIds, Set<Byte> staleDeviceIds) {
}

View File

@@ -5,23 +5,15 @@
package org.whispersystems.textsecuregcm.controllers;
import java.util.List;
public class MismatchedDevicesException extends Exception {
private final List<Byte> missingDevices;
private final List<Byte> extraDevices;
private final MismatchedDevices mismatchedDevices;
public MismatchedDevicesException(List<Byte> missingDevices, List<Byte> extraDevices) {
this.missingDevices = missingDevices;
this.extraDevices = extraDevices;
public MismatchedDevicesException(final MismatchedDevices mismatchedDevices) {
this.mismatchedDevices = mismatchedDevices;
}
public List<Byte> getMissingDevices() {
return missingDevices;
}
public List<Byte> getExtraDevices() {
return extraDevices;
public MismatchedDevices getMismatchedDevices() {
return mismatchedDevices;
}
}

View File

@@ -0,0 +1,24 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import java.util.Map;
public class MultiRecipientMismatchedDevicesException extends Exception {
private final Map<ServiceIdentifier, MismatchedDevices> mismatchedDevicesByServiceIdentifier;
public MultiRecipientMismatchedDevicesException(
final Map<ServiceIdentifier, MismatchedDevices> mismatchedDevicesByServiceIdentifier) {
this.mismatchedDevicesByServiceIdentifier = mismatchedDevicesByServiceIdentifier;
}
public Map<ServiceIdentifier, MismatchedDevices> getMismatchedDevicesByServiceIdentifier() {
return mismatchedDevicesByServiceIdentifier;
}
}

View File

@@ -1,22 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import java.util.List;
public class StaleDevicesException extends Exception {
private final List<Byte> staleDevices;
public StaleDevicesException(List<Byte> staleDevices) {
this.staleDevices = staleDevices;
}
public List<Byte> getStaleDevices() {
return staleDevices;
}
}

View File

@@ -14,5 +14,5 @@ public record AccountMismatchedDevices(@JsonSerialize(using = ServiceIdentifierA
@JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
ServiceIdentifier uuid,
MismatchedDevices devices) {
MismatchedDevicesResponse devices) {
}

View File

@@ -14,5 +14,5 @@ public record AccountStaleDevices(@JsonSerialize(using = ServiceIdentifierAdapte
@JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
ServiceIdentifier uuid,
StaleDevices devices) {
StaleDevicesResponse devices) {
}

View File

@@ -1,20 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.List;
public record MismatchedDevices(@JsonProperty
@Schema(description = "Devices present on the account but absent in the request")
List<Byte> missingDevices,
@JsonProperty
@Schema(description = "Devices absent on the request but present in the account")
List<Byte> extraDevices) {
}

View File

@@ -0,0 +1,20 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.Set;
public record MismatchedDevicesResponse(@JsonProperty
@Schema(description = "Devices present on the account but absent in the request")
Set<Byte> missingDevices,
@JsonProperty
@Schema(description = "Devices absent on the request but present in the account")
Set<Byte> extraDevices) {
}

View File

@@ -8,9 +8,9 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.List;
import java.util.Set;
public record StaleDevices(@JsonProperty
@Schema(description = "Devices that are no longer active")
List<Byte> staleDevices) {
public record StaleDevicesResponse(@JsonProperty
@Schema(description = "Devices that are no longer active")
Set<Byte> staleDevices) {
}

View File

@@ -11,13 +11,24 @@ import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.util.DataSize;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.util.Pair;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
@@ -58,6 +69,9 @@ public class MessageSender {
public static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes();
private static final long LARGE_MESSAGE_SIZE = DataSize.kibibytes(8).toBytes();
@VisibleForTesting
static final byte NO_EXCLUDED_DEVICE_ID = -1;
public MessageSender(final MessagesManager messagesManager, final PushNotificationManager pushNotificationManager) {
this.messagesManager = messagesManager;
this.pushNotificationManager = pushNotificationManager;
@@ -68,23 +82,51 @@ public class MessageSender {
* notification token and does not have an active connection to a Signal server, then this method will also send a
* push notification to that device to announce the availability of new messages.
*
* @param account the account to which to send messages
* @param destination the account to which to send messages
* @param destinationIdentifier the service identifier to which the messages are addressed
* @param messagesByDeviceId a map of device IDs to message payloads
* @param registrationIdsByDeviceId a map of device IDs to device registration IDs
*/
public void sendMessages(final Account account, final Map<Byte, Envelope> messagesByDeviceId) {
messagesManager.insert(account.getIdentifier(IdentityType.ACI), messagesByDeviceId)
public void sendMessages(final Account destination,
final ServiceIdentifier destinationIdentifier,
final Map<Byte, Envelope> messagesByDeviceId,
final Map<Byte, Integer> registrationIdsByDeviceId) throws MismatchedDevicesException {
if (messagesByDeviceId.isEmpty()) {
return;
}
if (!destination.isIdentifiedBy(destinationIdentifier)) {
throw new IllegalArgumentException("Destination account not identified by destination service identifier");
}
final Envelope firstMessage = messagesByDeviceId.values().iterator().next();
final boolean isSyncMessage = StringUtils.isNotBlank(firstMessage.getSourceServiceId()) &&
destination.isIdentifiedBy(ServiceIdentifier.valueOf(firstMessage.getSourceServiceId()));
final Optional<MismatchedDevices> maybeMismatchedDevices = getMismatchedDevices(destination,
destinationIdentifier,
registrationIdsByDeviceId,
isSyncMessage ? (byte) firstMessage.getSourceDevice() : NO_EXCLUDED_DEVICE_ID);
if (maybeMismatchedDevices.isPresent()) {
throw new MismatchedDevicesException(maybeMismatchedDevices.get());
}
messagesManager.insert(destination.getIdentifier(IdentityType.ACI), messagesByDeviceId)
.forEach((deviceId, destinationPresent) -> {
final Envelope message = messagesByDeviceId.get(deviceId);
if (!destinationPresent && !message.getEphemeral()) {
try {
pushNotificationManager.sendNewMessageNotification(account, deviceId, message.getUrgent());
pushNotificationManager.sendNewMessageNotification(destination, deviceId, message.getUrgent());
} catch (final NotPushRegisteredException ignored) {
}
}
Metrics.counter(SEND_COUNTER_NAME,
CHANNEL_TAG_NAME, account.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"),
CHANNEL_TAG_NAME, destination.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"),
EPHEMERAL_TAG_NAME, String.valueOf(message.getEphemeral()),
CLIENT_ONLINE_TAG_NAME, String.valueOf(destinationPresent),
URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
@@ -98,6 +140,10 @@ public class MessageSender {
* Sends messages to a group of recipients. If a destination device has a valid push notification token and does not
* have an active connection to a Signal server, then this method will also send a push notification to that device to
* announce the availability of new messages.
* <p>
* This method sends messages to all <em>resolved</em> recipients. In some cases, a caller may not be able to resolve
* all recipients to active accounts, but may still choose to send the message. Callers are responsible for rejecting
* the message if they require full resolution of all recipients, but some recipients could not be resolved.
*
* @param multiRecipientMessage the multi-recipient message to send to the given recipients
* @param resolvedRecipients a map of recipients to resolved Signal accounts
@@ -114,7 +160,31 @@ public class MessageSender {
final long clientTimestamp,
final boolean isStory,
final boolean isEphemeral,
final boolean isUrgent) {
final boolean isUrgent) throws MultiRecipientMismatchedDevicesException {
final Map<ServiceIdentifier, MismatchedDevices> mismatchedDevicesByServiceIdentifier = new HashMap<>();
multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> {
if (!resolvedRecipients.containsKey(recipient)) {
// Callers are responsible for rejecting messages if they're missing recipients in a problematic way. If we run
// into an unresolved recipient here, just skip it.
return;
}
final Account account = resolvedRecipients.get(recipient);
final ServiceIdentifier serviceIdentifier = ServiceIdentifier.fromLibsignal(serviceId);
final Map<Byte, Integer> registrationIdsByDeviceId = recipient.getDevicesAndRegistrationIds()
.collect(Collectors.toMap(Pair::first, pair -> (int) pair.second()));
getMismatchedDevices(account, serviceIdentifier, registrationIdsByDeviceId, NO_EXCLUDED_DEVICE_ID)
.ifPresent(mismatchedDevices ->
mismatchedDevicesByServiceIdentifier.put(serviceIdentifier, mismatchedDevices));
});
if (!mismatchedDevicesByServiceIdentifier.isEmpty()) {
throw new MultiRecipientMismatchedDevicesException(mismatchedDevicesByServiceIdentifier);
}
return messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp,
isStory, isEphemeral, isUrgent)
@@ -189,4 +259,47 @@ public class MessageSender {
.increment();
}
}
@VisibleForTesting
static Optional<MismatchedDevices> getMismatchedDevices(final Account account,
final ServiceIdentifier serviceIdentifier,
final Map<Byte, Integer> registrationIdsByDeviceId,
final byte excludedDeviceId) {
final Set<Byte> accountDeviceIds = account.getDevices().stream()
.map(Device::getId)
.filter(deviceId -> deviceId != excludedDeviceId)
.collect(Collectors.toSet());
final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds);
missingDeviceIds.removeAll(registrationIdsByDeviceId.keySet());
final Set<Byte> extraDeviceIds = new HashSet<>(registrationIdsByDeviceId.keySet());
extraDeviceIds.removeAll(accountDeviceIds);
final Set<Byte> staleDeviceIds = registrationIdsByDeviceId.entrySet().stream()
// Filter out device IDs that aren't associated with the given account
.filter(entry -> !extraDeviceIds.contains(entry.getKey()))
.filter(entry -> {
final byte deviceId = entry.getKey();
final int registrationId = entry.getValue();
// We know the device must be present because we've already filtered out device IDs that aren't associated
// with the given account
final Device device = account.getDevice(deviceId).orElseThrow();
final int expectedRegistrationId = switch (serviceIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId);
};
return registrationId != expectedRegistrationId;
})
.map(Map.Entry::getKey)
.collect(Collectors.toSet());
return (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty() || !staleDeviceIds.isEmpty())
? Optional.of(new MismatchedDevices(missingDeviceIds, extraDeviceIds, staleDeviceIds))
: Optional.empty();
}
}

View File

@@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.push;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import org.slf4j.Logger;
@@ -54,9 +55,20 @@ public class ReceiptSender {
.setUrgent(false)
.build();
final Map<Byte, Envelope> messagesByDeviceId = destinationAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, ignored -> message));
final Map<Byte, Integer> registrationIdsByDeviceId = destinationAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, device -> switch (destinationIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId);
}));
try {
messageSender.sendMessages(destinationAccount, destinationAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, ignored -> message)));
messageSender.sendMessages(destinationAccount,
destinationIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId);
} catch (final Exception e) {
logger.warn("Could not send delivery receipt", e);
}

View File

@@ -37,10 +37,12 @@ import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
@@ -55,6 +57,7 @@ import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import javax.crypto.Mac;
@@ -67,6 +70,7 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
@@ -83,7 +87,6 @@ import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecoveryException;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@@ -780,24 +783,33 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
}
// Check that all including primary ID are in signed pre-keys
DestinationDeviceValidator.validateCompleteDeviceList(
account,
pniSignedPreKeys.keySet(),
Collections.emptySet());
validateCompleteDeviceList(account, pniSignedPreKeys.keySet());
// Check that all including primary ID are in Pq pre-keys
if (pniPqLastResortPreKeys != null) {
DestinationDeviceValidator.validateCompleteDeviceList(
account,
pniPqLastResortPreKeys.keySet(),
Collections.emptySet());
validateCompleteDeviceList(account, pniPqLastResortPreKeys.keySet());
}
// Check that all devices are accounted for in the map of new PNI registration IDs
DestinationDeviceValidator.validateCompleteDeviceList(
account,
pniRegistrationIds.keySet(),
Collections.emptySet());
validateCompleteDeviceList(account, pniRegistrationIds.keySet());
}
@VisibleForTesting
static void validateCompleteDeviceList(final Account account, final Set<Byte> deviceIds) throws MismatchedDevicesException {
final Set<Byte> accountDeviceIds = account.getDevices().stream()
.map(Device::getId)
.collect(Collectors.toSet());
final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds);
missingDeviceIds.removeAll(deviceIds);
final Set<Byte> extraDeviceIds = new HashSet<>(deviceIds);
extraDeviceIds.removeAll(accountDeviceIds);
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
throw new MismatchedDevicesException(
new MismatchedDevices(missingDeviceIds, extraDeviceIds, Collections.emptySet()));
}
}
public record UsernameReservation(Account account, byte[] reservedUsernameHash){}

View File

@@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.ObjectUtils;
@@ -15,15 +14,15 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
public class ChangeNumberManager {
@@ -45,12 +44,10 @@ public class ChangeNumberManager {
@Nullable final List<IncomingMessage> deviceMessages,
@Nullable final Map<Byte, Integer> pniRegistrationIds,
@Nullable final String senderUserAgent)
throws InterruptedException, MismatchedDevicesException, StaleDevicesException, MessageTooLargeException {
throws InterruptedException, MismatchedDevicesException, MessageTooLargeException {
if (ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) {
// AccountsManager validates the device set on deviceSignedPreKeys and pniRegistrationIds
validateDeviceMessages(account, deviceMessages);
} else if (!ObjectUtils.allNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) {
if (!(ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds) ||
ObjectUtils.allNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds))) {
throw new IllegalArgumentException("PNI identity key, signed pre-keys, device messages, and registration IDs must be all null or all non-null");
}
@@ -84,9 +81,7 @@ public class ChangeNumberManager {
@Nullable final Map<Byte, KEMSignedPreKey> devicePqLastResortPreKeys,
final List<IncomingMessage> deviceMessages,
final Map<Byte, Integer> pniRegistrationIds,
final String senderUserAgent) throws MismatchedDevicesException, StaleDevicesException, MessageTooLargeException {
validateDeviceMessages(account, deviceMessages);
final String senderUserAgent) throws MismatchedDevicesException, MessageTooLargeException {
// Don't try to be smart about ignoring unnecessary retries. If we make literally no change we will skip the ddb
// write anyway. Linked devices can handle some wasted extra key rotations.
@@ -97,26 +92,9 @@ public class ChangeNumberManager {
return updatedAccount;
}
private void validateDeviceMessages(final Account account,
final List<IncomingMessage> deviceMessages) throws MismatchedDevicesException, StaleDevicesException {
// Check that all except primary ID are in device messages
DestinationDeviceValidator.validateCompleteDeviceList(
account,
deviceMessages.stream().map(IncomingMessage::destinationDeviceId).collect(Collectors.toSet()),
Set.of(Device.PRIMARY_ID));
// check that all sync messages are to the current registration ID for the matching device
DestinationDeviceValidator.validateRegistrationIds(
account,
deviceMessages,
IncomingMessage::destinationDeviceId,
IncomingMessage::destinationRegistrationId,
false);
}
private void sendDeviceMessages(final Account account,
final List<IncomingMessage> deviceMessages,
final String senderUserAgent) throws MessageTooLargeException {
final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException {
for (final IncomingMessage message : deviceMessages) {
MessageSender.validateContentLength(message.content().length,
@@ -128,20 +106,26 @@ public class ChangeNumberManager {
try {
final long serverTimestamp = System.currentTimeMillis();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid());
messageSender.sendMessages(account, deviceMessages.stream()
final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> Envelope.newBuilder()
.setType(Envelope.Type.forNumber(message.type()))
.setClientTimestamp(serverTimestamp)
.setServerTimestamp(serverTimestamp)
.setDestinationServiceId(new AciServiceIdentifier(account.getUuid()).toServiceIdentifierString())
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.setContent(ByteString.copyFrom(message.content()))
.setSourceServiceId(new AciServiceIdentifier(account.getUuid()).toServiceIdentifierString())
.setSourceServiceId(serviceIdentifier.toServiceIdentifierString())
.setSourceDevice(Device.PRIMARY_ID)
.setUpdatedPni(account.getPhoneNumberIdentifier().toString())
.setUrgent(true)
.setEphemeral(false)
.build())));
.build()));
final Map<Byte, Integer> registrationIdsByDeviceId = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, Device::getRegistrationId));
messageSender.sendMessages(account, serviceIdentifier, messagesByDeviceId, registrationIdsByDeviceId);
} catch (final RuntimeException e) {
logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e);
throw e;

View File

@@ -1,108 +0,0 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
public class DestinationDeviceValidator {
/**
* @see #validateRegistrationIds(Account, Stream, boolean)
*/
public static <T> void validateRegistrationIds(final Account account,
final Collection<T> messages,
Function<T, Byte> getDeviceId,
Function<T, Integer> getRegistrationId,
boolean usePhoneNumberIdentity) throws StaleDevicesException {
validateRegistrationIds(account,
messages.stream().map(m -> new Pair<>(getDeviceId.apply(m), getRegistrationId.apply(m))),
usePhoneNumberIdentity);
}
/**
* Validates that the given device ID/registration ID pairs exactly match the corresponding device ID/registration ID
* pairs in the given destination account. This method does <em>not</em> validate that all devices associated with the
* destination account are present in the given device ID/registration ID pairs.
*
* @param account the destination account against which to check the given device
* ID/registration ID pairs
* @param deviceIdAndRegistrationIdStream a stream of device ID and registration ID pairs
* @param usePhoneNumberIdentity if {@code true}, compare provided registration IDs against device
* registration IDs associated with the account's PNI (if available); compare
* against the ACI-associated registration ID otherwise
* @throws StaleDevicesException if the device ID/registration ID pairs contained an entry for which the destination
* account does not have a corresponding device or if the registration IDs do not match
*/
public static void validateRegistrationIds(final Account account,
final Stream<Pair<Byte, Integer>> deviceIdAndRegistrationIdStream,
final boolean usePhoneNumberIdentity) throws StaleDevicesException {
final List<Byte> staleDevices = deviceIdAndRegistrationIdStream
.filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0)
.filter(deviceIdAndRegistrationId -> {
final byte deviceId = deviceIdAndRegistrationId.first();
final int registrationId = deviceIdAndRegistrationId.second();
boolean registrationIdMatches = account.getDevice(deviceId)
.map(device -> registrationId == (usePhoneNumberIdentity
? device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId())
: device.getRegistrationId()))
.orElse(false);
return !registrationIdMatches;
})
.map(Pair::first)
.collect(Collectors.toList());
if (!staleDevices.isEmpty()) {
throw new StaleDevicesException(staleDevices);
}
}
/**
* Validates that the given set of device IDs from a set of messages matches the set of device IDs associated with the
* given destination account in preparation for sending those messages to the destination account. In general, the set
* of device IDs must exactly match the set of active devices associated with the destination account. When sending a
* "sync," message, though, the authenticated account is sending messages from one of their devices to all other
* devices; in that case, callers must pass the ID of the sending device in the set of {@code excludedDeviceIds}.
*
* @param account the destination account against which to check the given set of device IDs
* @param messageDeviceIds the set of device IDs to check against the destination account
* @param excludedDeviceIds a set of device IDs that may be associated with the destination account, but must not be
* present in the given set of device IDs (i.e. the device that is sending a sync message)
* @throws MismatchedDevicesException if the given set of device IDs contains entries not currently associated with
* the destination account or is missing entries associated with the destination
* account
*/
public static void validateCompleteDeviceList(final Account account,
final Set<Byte> messageDeviceIds,
final Set<Byte> excludedDeviceIds) throws MismatchedDevicesException {
final Set<Byte> accountDeviceIds = account.getDevices().stream()
.map(Device::getId)
.filter(deviceId -> !excludedDeviceIds.contains(deviceId))
.collect(Collectors.toSet());
final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds);
missingDeviceIds.removeAll(messageDeviceIds);
final Set<Byte> extraDeviceIds = new HashSet<>(messageDeviceIds);
extraDeviceIds.removeAll(accountDeviceIds);
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
throw new MismatchedDevicesException(new ArrayList<>(missingDeviceIds), new ArrayList<>(extraDeviceIds));
}
}
}