Internalize destination device list/registration ID checks in MessageSender

This commit is contained in:
Jon Chambers
2025-04-07 09:15:39 -04:00
committed by GitHub
parent 1d0e2d29a7
commit c6689ca07a
21 changed files with 675 additions and 755 deletions

View File

@@ -43,12 +43,12 @@ import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager
import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse;
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse;
import org.whispersystems.textsecuregcm.entities.PhoneNumberDiscoverabilityRequest;
import org.whispersystems.textsecuregcm.entities.PhoneNumberIdentityKeyDistributionRequest;
import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest;
import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
@@ -93,8 +93,8 @@ public class AccountControllerV2 {
@ApiResponse(responseCode = "200", description = "The phone number associated with the authenticated account was changed successfully", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "403", description = "Verification failed for the provided Registration Recovery Password")
@ApiResponse(responseCode = "409", description = "Mismatched number of devices or device ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = MismatchedDevices.class)))
@ApiResponse(responseCode = "410", description = "Mismatched registration ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = StaleDevices.class)))
@ApiResponse(responseCode = "409", description = "Mismatched number of devices or device ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class)))
@ApiResponse(responseCode = "410", description = "Mismatched registration ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
@ApiResponse(responseCode = "413", description = "One or more device messages was too large")
@ApiResponse(responseCode = "422", description = "The request did not pass validation")
@ApiResponse(responseCode = "423", content = @Content(schema = @Schema(implementation = RegistrationLockFailure.class)))
@@ -150,16 +150,18 @@ public class AccountControllerV2 {
return AccountIdentityResponseBuilder.fromAccount(updatedAccount);
} catch (MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds()))
.build());
} else {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(),
e.getMismatchedDevices().extraDeviceIds()))
.build());
}
} catch (IllegalArgumentException e) {
throw new BadRequestException(e);
} catch (MessageTooLargeException e) {
@@ -178,9 +180,9 @@ public class AccountControllerV2 {
@ApiResponse(responseCode = "403", description = "This endpoint can only be invoked from the account's primary device.")
@ApiResponse(responseCode = "422", description = "The request body failed validation.")
@ApiResponse(responseCode = "409", description = "The set of devices specified in the request does not match the set of devices active on the account.",
content = @Content(schema = @Schema(implementation = MismatchedDevices.class)))
content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class)))
@ApiResponse(responseCode = "410", description = "The registration IDs provided for some devices do not match those stored on the server.",
content = @Content(schema = @Schema(implementation = StaleDevices.class)))
content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
@ApiResponse(responseCode = "413", description = "One or more device messages was too large")
public AccountIdentityResponse distributePhoneNumberIdentityKeys(
@Mutable @Auth final AuthenticatedDevice authenticatedDevice,
@@ -207,16 +209,18 @@ public class AccountControllerV2 {
return AccountIdentityResponseBuilder.fromAccount(updatedAccount);
} catch (MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds()))
.build());
} else {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(),
e.getMismatchedDevices().extraDeviceIds()))
.build());
}
} catch (IllegalArgumentException e) {
throw new BadRequestException(e);
} catch (MessageTooLargeException e) {

View File

@@ -44,14 +44,12 @@ import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status;
import java.time.Clock;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
@@ -62,7 +60,6 @@ import javax.annotation.Nullable;
import org.glassfish.jersey.server.ManagedAsync;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.util.Pair;
import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.signal.libsignal.zkgroup.VerificationFailedException;
import org.signal.libsignal.zkgroup.groupsend.GroupSendDerivedKeyPair;
@@ -81,13 +78,13 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope.Type;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMessageResponse;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
@@ -113,7 +110,6 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
@@ -245,10 +241,10 @@ public class MessageController {
description="The message is not a story and some the recipient service ID does not correspond to a registered Signal user")
@ApiResponse(
responseCode = "409", description = "Incorrect set of devices supplied for recipient",
content = @Content(schema = @Schema(implementation = MismatchedDevices.class)))
content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class)))
@ApiResponse(
responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices",
content = @Content(schema = @Schema(implementation = StaleDevices.class)))
content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
@ApiResponse(
responseCode="428",
description="The sender should complete a challenge before proceeding")
@@ -381,14 +377,6 @@ public class MessageController {
rateLimiters.getStoriesLimiter().validate(destination.getUuid());
}
final Set<Byte> excludedDeviceIds;
if (isSyncMessage) {
excludedDeviceIds = Set.of(source.get().getAuthenticatedDevice().getId());
} else {
excludedDeviceIds = Collections.emptySet();
}
final Map<Byte, Envelope> messagesByDeviceId = messages.messages().stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> {
try {
@@ -407,15 +395,8 @@ public class MessageController {
}
}));
DestinationDeviceValidator.validateCompleteDeviceList(destination,
messagesByDeviceId.keySet(),
excludedDeviceIds);
DestinationDeviceValidator.validateRegistrationIds(destination,
messages.messages(),
IncomingMessage::destinationDeviceId,
IncomingMessage::destinationRegistrationId,
destination.getPhoneNumberIdentifier().equals(destinationIdentifier.uuid()));
final Map<Byte, Integer> registrationIdsByDeviceId = messages.messages().stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
final String authType;
if (SENDER_TYPE_IDENTIFIED.equals(senderType)) {
@@ -428,7 +409,7 @@ public class MessageController {
authType = AUTH_TYPE_ACCESS_KEY;
}
messageSender.sendMessages(destination, messagesByDeviceId);
messageSender.sendMessages(destination, destinationIdentifier, messagesByDeviceId, registrationIdsByDeviceId);
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, List.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_SINGLE),
@@ -440,16 +421,18 @@ public class MessageController {
return Response.ok(new SendMessageResponse(needsSync)).build();
} catch (final MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (final StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds()))
.build());
} else {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(),
e.getMismatchedDevices().extraDeviceIds()))
.build());
}
}
} finally {
sample.stop(Timer.builder(SEND_MESSAGE_LATENCY_TIMER_NAME)
@@ -622,57 +605,6 @@ public class MessageController {
}
}
final Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
final Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> {
if (!resolvedRecipients.containsKey(recipient)) {
// When sending stories, we might not be able to resolve all recipients to existing accounts. That's okay! We
// can just skip them.
return;
}
final Account account = resolvedRecipients.get(recipient);
try {
final Map<Byte, Short> deviceIdsToRegistrationIds = recipient.getDevicesAndRegistrationIds()
.collect(Collectors.toMap(Pair<Byte, Short>::first, Pair<Byte, Short>::second));
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIdsToRegistrationIds.keySet(),
Collections.emptySet());
DestinationDeviceValidator.validateRegistrationIds(
account,
deviceIdsToRegistrationIds.entrySet(),
Map.Entry<Byte, Short>::getKey,
e -> Integer.valueOf(e.getValue()),
serviceId instanceof ServiceId.Pni);
} catch (final MismatchedDevicesException e) {
accountMismatchedDevices.add(
new AccountMismatchedDevices(
ServiceIdentifier.fromLibsignal(serviceId),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (final StaleDevicesException e) {
accountStaleDevices.add(
new AccountStaleDevices(ServiceIdentifier.fromLibsignal(serviceId), new StaleDevices(e.getStaleDevices())));
}
});
if (!accountMismatchedDevices.isEmpty()) {
return Response
.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(accountMismatchedDevices)
.build();
}
if (!accountStaleDevices.isEmpty()) {
return Response
.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(accountStaleDevices)
.build();
}
final String authType;
if (isStory) {
authType = AUTH_TYPE_STORY;
@@ -731,6 +663,38 @@ public class MessageController {
} catch (ExecutionException e) {
logger.error("partial failure while delivering multi-recipient messages", e.getCause());
throw new InternalServerErrorException("failure during delivery");
} catch (MultiRecipientMismatchedDevicesException e) {
final List<AccountMismatchedDevices> accountMismatchedDevices =
e.getMismatchedDevicesByServiceIdentifier().entrySet().stream()
.filter(entry -> !entry.getValue().missingDeviceIds().isEmpty() || !entry.getValue().extraDeviceIds().isEmpty())
.map(entry -> new AccountMismatchedDevices(entry.getKey(),
new MismatchedDevicesResponse(entry.getValue().missingDeviceIds(), entry.getValue().extraDeviceIds())))
.toList();
if (!accountMismatchedDevices.isEmpty()) {
return Response
.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(accountMismatchedDevices)
.build();
}
final List<AccountStaleDevices> accountStaleDevices =
e.getMismatchedDevicesByServiceIdentifier().entrySet().stream()
.filter(entry -> !entry.getValue().staleDeviceIds().isEmpty())
.map(entry -> new AccountStaleDevices(entry.getKey(),
new StaleDevicesResponse(entry.getValue().staleDeviceIds())))
.toList();
if (!accountStaleDevices.isEmpty()) {
return Response
.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(accountStaleDevices)
.build();
}
throw new RuntimeException(e);
}
}

View File

@@ -0,0 +1,11 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import java.util.Set;
public record MismatchedDevices(Set<Byte> missingDeviceIds, Set<Byte> extraDeviceIds, Set<Byte> staleDeviceIds) {
}

View File

@@ -5,23 +5,15 @@
package org.whispersystems.textsecuregcm.controllers;
import java.util.List;
public class MismatchedDevicesException extends Exception {
private final List<Byte> missingDevices;
private final List<Byte> extraDevices;
private final MismatchedDevices mismatchedDevices;
public MismatchedDevicesException(List<Byte> missingDevices, List<Byte> extraDevices) {
this.missingDevices = missingDevices;
this.extraDevices = extraDevices;
public MismatchedDevicesException(final MismatchedDevices mismatchedDevices) {
this.mismatchedDevices = mismatchedDevices;
}
public List<Byte> getMissingDevices() {
return missingDevices;
}
public List<Byte> getExtraDevices() {
return extraDevices;
public MismatchedDevices getMismatchedDevices() {
return mismatchedDevices;
}
}

View File

@@ -0,0 +1,24 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import java.util.Map;
public class MultiRecipientMismatchedDevicesException extends Exception {
private final Map<ServiceIdentifier, MismatchedDevices> mismatchedDevicesByServiceIdentifier;
public MultiRecipientMismatchedDevicesException(
final Map<ServiceIdentifier, MismatchedDevices> mismatchedDevicesByServiceIdentifier) {
this.mismatchedDevicesByServiceIdentifier = mismatchedDevicesByServiceIdentifier;
}
public Map<ServiceIdentifier, MismatchedDevices> getMismatchedDevicesByServiceIdentifier() {
return mismatchedDevicesByServiceIdentifier;
}
}

View File

@@ -1,22 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import java.util.List;
public class StaleDevicesException extends Exception {
private final List<Byte> staleDevices;
public StaleDevicesException(List<Byte> staleDevices) {
this.staleDevices = staleDevices;
}
public List<Byte> getStaleDevices() {
return staleDevices;
}
}