mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-20 17:38:04 +01:00
Add logic to handle sending a common payload to multiple recipients
This commit is contained in:
@@ -20,17 +20,24 @@ import io.lettuce.core.ScriptOutputType;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import io.micrometer.core.instrument.Tag;
|
||||
import java.io.IOException;
|
||||
import java.security.MessageDigest;
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Base64;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.validation.Valid;
|
||||
import javax.ws.rs.Consumes;
|
||||
import javax.ws.rs.DELETE;
|
||||
@@ -40,26 +47,33 @@ import javax.ws.rs.PUT;
|
||||
import javax.ws.rs.Path;
|
||||
import javax.ws.rs.PathParam;
|
||||
import javax.ws.rs.Produces;
|
||||
import javax.ws.rs.QueryParam;
|
||||
import javax.ws.rs.WebApplicationException;
|
||||
import javax.ws.rs.core.MediaType;
|
||||
import javax.ws.rs.core.Response;
|
||||
import javax.ws.rs.core.Response.Status;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
|
||||
import org.whispersystems.textsecuregcm.auth.Anonymous;
|
||||
import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys;
|
||||
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
|
||||
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
|
||||
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope.Type;
|
||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage;
|
||||
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient;
|
||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
|
||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
|
||||
import org.whispersystems.textsecuregcm.entities.SendMessageResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.StaleDevices;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
|
||||
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
|
||||
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
|
||||
import org.whispersystems.textsecuregcm.push.MessageSender;
|
||||
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
|
||||
@@ -89,6 +103,7 @@ public class MessageController {
|
||||
private final Meter identifiedMeter = metricRegistry.meter(name(getClass(), "delivery", "identified" ));
|
||||
private final Meter rejectOver256kibMessageMeter = metricRegistry.meter(name(getClass(), "rejectOver256kibMessage"));
|
||||
private final Timer sendMessageInternalTimer = metricRegistry.timer(name(getClass(), "sendMessageInternal"));
|
||||
private final Timer sendCommonMessageInternalTimer = metricRegistry.timer(name(getClass(), "sendCommonMessageInternal"));
|
||||
private final Histogram outgoingMessageListSizeHistogram = metricRegistry.histogram(name(getClass(), "outgoingMessageListSize"));
|
||||
|
||||
private final RateLimiters rateLimiters;
|
||||
@@ -295,6 +310,99 @@ public class MessageController {
|
||||
}
|
||||
}
|
||||
|
||||
@Timed
|
||||
@Path("/multi_recipient")
|
||||
@PUT
|
||||
@Consumes(MultiRecipientMessageProvider.MEDIA_TYPE)
|
||||
@Produces(MediaType.APPLICATION_JSON)
|
||||
public Response sendMultiRecipientMessage(
|
||||
@HeaderParam(OptionalAccess.UNIDENTIFIED) CombinedUnidentifiedSenderAccessKeys accessKeys,
|
||||
@HeaderParam("User-Agent") String userAgent,
|
||||
@HeaderParam("X-Forwarded-For") String forwardedFor,
|
||||
@QueryParam("online") boolean online,
|
||||
@QueryParam("ts") long timestamp,
|
||||
@Valid MultiRecipientMessage multiRecipientMessage) {
|
||||
|
||||
unidentifiedMeter.mark(multiRecipientMessage.getRecipients().length);
|
||||
|
||||
Map<UUID, Account> uuidToAccountMap = Arrays.stream(multiRecipientMessage.getRecipients())
|
||||
.map(Recipient::getUuid)
|
||||
.distinct()
|
||||
.collect(Collectors.toMap(Function.identity(), uuid -> {
|
||||
Optional<Account> account = accountsManager.get(uuid);
|
||||
if (account.isEmpty()) {
|
||||
throw new WebApplicationException(Status.NOT_FOUND);
|
||||
}
|
||||
return account.get();
|
||||
}));
|
||||
checkAccessKeys(accessKeys, uuidToAccountMap);
|
||||
|
||||
try {
|
||||
for (Account account : uuidToAccountMap.values()) {
|
||||
Set<Long> deviceIds = Arrays.stream(multiRecipientMessage.getRecipients())
|
||||
.filter(recipient -> recipient.getUuid().equals(account.getUuid()))
|
||||
.map(Recipient::getDeviceId)
|
||||
.collect(Collectors.toSet());
|
||||
validateCompleteDeviceList(account, deviceIds, false);
|
||||
}
|
||||
|
||||
List<Tag> tags = List.of(
|
||||
UserAgentTagUtil.getPlatformTag(userAgent),
|
||||
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)),
|
||||
Tag.of(SENDER_TYPE_TAG_NAME, "unidentified"));
|
||||
List<UUID> uuids404 = new ArrayList<>();
|
||||
for (Recipient recipient : multiRecipientMessage.getRecipients()) {
|
||||
|
||||
Account destinationAccount = uuidToAccountMap.get(recipient.getUuid());
|
||||
// we asserted this must be true in validateCompleteDeviceList
|
||||
//noinspection OptionalGetWithoutIsPresent
|
||||
Device destinationDevice = destinationAccount.getDevice(recipient.getDeviceId()).get();
|
||||
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment();
|
||||
try {
|
||||
sendMessage(destinationAccount, destinationDevice, timestamp, online, recipient,
|
||||
multiRecipientMessage.getCommonPayload());
|
||||
} catch (NoSuchUserException e) {
|
||||
uuids404.add(destinationAccount.getUuid());
|
||||
}
|
||||
}
|
||||
return Response.ok(new SendMessageResponse(uuids404)).build();
|
||||
} catch (MismatchedDevicesException e) {
|
||||
throw new WebApplicationException(Response
|
||||
.status(409)
|
||||
.type(MediaType.APPLICATION_JSON_TYPE)
|
||||
.entity(new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))
|
||||
.build());
|
||||
}
|
||||
}
|
||||
|
||||
private void checkAccessKeys(CombinedUnidentifiedSenderAccessKeys accessKeys, Map<UUID, Account> uuidToAccountMap) {
|
||||
AtomicBoolean throwUnauthorized = new AtomicBoolean(false);
|
||||
byte[] empty = new byte[16];
|
||||
byte[] combinedUnknownAccessKeys = uuidToAccountMap.values().stream()
|
||||
.map(Account::getUnidentifiedAccessKey)
|
||||
.map(accessKey -> {
|
||||
if (accessKey.isEmpty()) {
|
||||
throwUnauthorized.set(true);
|
||||
return empty;
|
||||
}
|
||||
return accessKey.get();
|
||||
})
|
||||
.reduce(new byte[16], (bytes, bytes2) -> {
|
||||
if (bytes.length != bytes2.length) {
|
||||
throwUnauthorized.set(true);
|
||||
return bytes;
|
||||
}
|
||||
for (int i = 0; i < bytes.length; i++) {
|
||||
bytes[i] ^= bytes2[i];
|
||||
}
|
||||
return bytes;
|
||||
});
|
||||
if (throwUnauthorized.get()
|
||||
|| !MessageDigest.isEqual(combinedUnknownAccessKeys, accessKeys.getAccessKeys())) {
|
||||
throw new WebApplicationException(Status.UNAUTHORIZED);
|
||||
}
|
||||
}
|
||||
|
||||
private Response declineDelivery(final IncomingMessageList messages, final Account source, final Account destination) {
|
||||
Metrics.counter(DECLINED_DELIVERY_COUNTER, SENDER_COUNTRY_TAG_NAME, Util.getCountryCode(source.getNumber())).increment();
|
||||
|
||||
@@ -464,6 +572,34 @@ public class MessageController {
|
||||
}
|
||||
}
|
||||
|
||||
private void sendMessage(Account destinationAccount, Device destinationDevice, long timestamp, boolean online,
|
||||
Recipient recipient, byte[] commonPayload) throws NoSuchUserException {
|
||||
try (final Timer.Context ignored = sendCommonMessageInternalTimer.time()) {
|
||||
Envelope.Builder messageBuilder = Envelope.newBuilder();
|
||||
long serverTimestamp = System.currentTimeMillis();
|
||||
byte[] recipientKeyMaterial = recipient.getPerRecipientKeyMaterial();
|
||||
|
||||
byte[] payload = new byte[1 + recipientKeyMaterial.length + commonPayload.length];
|
||||
payload[0] = MultiRecipientMessageProvider.VERSION;
|
||||
System.arraycopy(recipientKeyMaterial, 0, payload, 1, recipientKeyMaterial.length);
|
||||
System.arraycopy(commonPayload, 0, payload, 1 + recipientKeyMaterial.length, payload.length);
|
||||
|
||||
messageBuilder
|
||||
.setType(Type.UNIDENTIFIED_SENDER)
|
||||
.setTimestamp(timestamp == 0 ? serverTimestamp : timestamp)
|
||||
.setServerTimestamp(serverTimestamp)
|
||||
.setContent(ByteString.copyFrom(payload));
|
||||
|
||||
messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
|
||||
} catch (NotPushRegisteredException e) {
|
||||
if (destinationDevice.isMaster()) {
|
||||
throw new NoSuchUserException(e);
|
||||
} else {
|
||||
logger.debug("Not registered", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void validateRegistrationIds(Account account, List<IncomingMessage> messages)
|
||||
throws StaleDevicesException
|
||||
{
|
||||
@@ -485,22 +621,24 @@ public class MessageController {
|
||||
}
|
||||
}
|
||||
|
||||
private void validateCompleteDeviceList(Account account, List<IncomingMessage> messages, boolean isSyncMessage)
|
||||
throws MismatchedDevicesException {
|
||||
Set<Long> messageDeviceIds = messages.stream().map(IncomingMessage::getDestinationDeviceId).collect(Collectors.toSet());
|
||||
validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage);
|
||||
}
|
||||
|
||||
|
||||
private void validateCompleteDeviceList(Account account,
|
||||
List<IncomingMessage> messages,
|
||||
Set<Long> messageDeviceIds,
|
||||
boolean isSyncMessage)
|
||||
throws MismatchedDevicesException
|
||||
{
|
||||
Set<Long> messageDeviceIds = new HashSet<>();
|
||||
Set<Long> accountDeviceIds = new HashSet<>();
|
||||
|
||||
List<Long> missingDeviceIds = new LinkedList<>();
|
||||
List<Long> extraDeviceIds = new LinkedList<>();
|
||||
|
||||
for (IncomingMessage message : messages) {
|
||||
messageDeviceIds.add(message.getDestinationDeviceId());
|
||||
}
|
||||
|
||||
for (Device device : account.getDevices()) {
|
||||
for (Device device : account.getDevices()) {
|
||||
if (device.isEnabled() &&
|
||||
!(isSyncMessage && device.getId() == account.getAuthenticatedDevice().get().getId()))
|
||||
{
|
||||
@@ -512,9 +650,9 @@ public class MessageController {
|
||||
}
|
||||
}
|
||||
|
||||
for (IncomingMessage message : messages) {
|
||||
if (!accountDeviceIds.contains(message.getDestinationDeviceId())) {
|
||||
extraDeviceIds.add(message.getDestinationDeviceId());
|
||||
for (Long deviceId : messageDeviceIds) {
|
||||
if (!accountDeviceIds.contains(deviceId)) {
|
||||
extraDeviceIds.add(deviceId);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user