Replace MultiRecipientMessage parsing with libsignal's implementation

Co-authored-by: Jonathan Klabunde Tomer <jkt@signal.org>
This commit is contained in:
Jordan Rose
2023-12-08 08:52:47 -08:00
committed by GitHub
parent f20d3043d6
commit 2ab3c97ee8
8 changed files with 64 additions and 337 deletions

View File

@@ -25,13 +25,9 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
import java.security.MessageDigest;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -42,9 +38,6 @@ import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -73,6 +66,9 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status;
import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage.Recipient;
import org.signal.libsignal.protocol.util.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.Anonymous;
@@ -88,8 +84,6 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope.Type;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage;
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMessageResponse;
@@ -118,13 +112,13 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
import org.whispersystems.websocket.Stories;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.util.function.Tuples;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v1/messages")
@@ -134,7 +128,8 @@ public class MessageController {
private record MultiRecipientDeliveryData(
ServiceIdentifier serviceIdentifier,
Account account,
Map<Byte, Recipient> perDeviceData) {
Recipient recipient,
Map<Byte, Short> deviceIdToRegistrationId) {
}
private static final Logger logger = LoggerFactory.getLogger(MessageController.class);
@@ -369,20 +364,22 @@ public class MessageController {
* Build mapping of service IDs to resolved accounts and device/registration IDs
*/
private Map<ServiceIdentifier, MultiRecipientDeliveryData> buildRecipientMap(
MultiRecipientMessage multiRecipientMessage, boolean isStory) {
return Flux.fromArray(multiRecipientMessage.recipients())
.groupBy(Recipient::uuid, multiRecipientMessage.recipients().length)
SealedSenderMultiRecipientMessage multiRecipientMessage, boolean isStory) {
return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet())
.map(e -> Tuples.of(ServiceIdentifier.fromLibsignal(e.getKey()), e.getValue()))
.flatMap(
gf -> Mono.justOrEmpty(accountsManager.getByServiceIdentifier(gf.key()))
t -> Mono.justOrEmpty(accountsManager.getByServiceIdentifier(t.getT1()))
.switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new))
.flatMap(
.map(
account ->
gf.collectMap(Recipient::deviceId)
.map(perRecipientData ->
new MultiRecipientDeliveryData(
gf.key(),
account,
perRecipientData))))
new MultiRecipientDeliveryData(
t.getT1(),
account,
t.getT2(),
t.getT2().getDevicesAndRegistrationIds().collect(
Collectors.toMap(Pair<Byte, Short>::first, Pair<Byte, Short>::second))))
// IllegalStateException is thrown by Collectors#toMap when we have multiple entries for the same device
.onErrorMap(e -> e instanceof IllegalStateException ? new BadRequestException() : e))
.collectMap(MultiRecipientDeliveryData::serviceIdentifier)
.block();
}
@@ -429,8 +426,8 @@ public class MessageController {
@Parameter(description="If true, the message is a story; access tokens are not checked and sending to nonexistent recipients is permitted")
@QueryParam("story") boolean isStory,
@Parameter(description="The sealed-sender multi-recipient message payload")
@NotNull @Valid MultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException {
@Parameter(description="The sealed-sender multi-recipient message payload as serialized by libsignal")
@NotNull SealedSenderMultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException {
final Map<ServiceIdentifier, MultiRecipientDeliveryData> recipients = buildRecipientMap(multiRecipientMessage, isStory);
@@ -456,13 +453,13 @@ public class MessageController {
final Account account = recipient.account();
try {
DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.perDeviceData().keySet(), Collections.emptySet());
DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.deviceIdToRegistrationId().keySet(), Collections.emptySet());
DestinationDeviceValidator.validateRegistrationIds(
account,
recipient.perDeviceData().values(),
Recipient::deviceId,
Recipient::registrationId,
recipient.deviceIdToRegistrationId().entrySet(),
Map.Entry<Byte, Short>::getKey,
e -> Integer.valueOf(e.getValue()),
recipient.serviceIdentifier().identityType() == IdentityType.PNI);
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(
@@ -500,17 +497,19 @@ public class MessageController {
CompletableFuture.allOf(
recipients.values().stream()
.flatMap(recipientData ->
recipientData.perDeviceData().values().stream().map(
recipient -> CompletableFuture.runAsync(
recipientData.deviceIdToRegistrationId().keySet().stream().map(
deviceId ->CompletableFuture.runAsync(
() -> {
final Account destinationAccount = recipientData.account();
final byte[] payload = multiRecipientMessage.messageForRecipient(recipientData.recipient());
// we asserted this must exist in validateCompleteDeviceList
final Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow();
final Device destinationDevice = destinationAccount.getDevice(deviceId).orElseThrow();
try {
sentMessageCounter.increment();
sendCommonPayloadMessage(
destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp, online,
isStory, isUrgent, recipient, multiRecipientMessage.commonPayload());
isStory, isUrgent, payload);
} catch (NoSuchUserException e) {
// this should never happen, because we already asserted the device is present and enabled
Metrics.counter(
@@ -740,17 +739,10 @@ public class MessageController {
boolean online,
boolean story,
boolean urgent,
Recipient recipient,
byte[] commonPayload) throws NoSuchUserException {
byte[] payload) throws NoSuchUserException {
try {
Envelope.Builder messageBuilder = Envelope.newBuilder();
long serverTimestamp = System.currentTimeMillis();
byte[] recipientKeyMaterial = recipient.perRecipientKeyMaterial();
byte[] payload = new byte[1 + recipientKeyMaterial.length + commonPayload.length];
payload[0] = MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER;
System.arraycopy(recipientKeyMaterial, 0, payload, 1, recipientKeyMaterial.length);
System.arraycopy(commonPayload, 0, payload, 1 + recipientKeyMaterial.length, commonPayload.length);
messageBuilder
.setType(Type.UNIDENTIFIED_SENDER)