Add handling of registration id in multi recipient send payload

This commit is contained in:
Ehren Kret
2021-05-17 13:01:37 -05:00
parent 10cd60738a
commit f76e6705c0
4 changed files with 86 additions and 15 deletions

View File

@@ -38,6 +38,7 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
@@ -62,6 +63,7 @@ import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKey
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices;
import org.whispersystems.textsecuregcm.entities.AccountStaleDevices;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
@@ -94,6 +96,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.ForwardedIpUtil;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
@@ -360,16 +363,23 @@ public class MessageController {
checkAccessKeys(accessKeys, uuidToAccountMap);
List<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
List<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
for (Account account : uuidToAccountMap.values()) {
Set<Long> deviceIds = Arrays.stream(multiRecipientMessage.getRecipients())
.filter(recipient -> recipient.getUuid().equals(account.getUuid()))
.map(Recipient::getDeviceId)
.collect(Collectors.toSet());
Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream = Arrays.stream(multiRecipientMessage.getRecipients())
.filter(recipient -> recipient.getUuid().equals(account.getUuid()))
.map(recipient -> new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()));
try {
validateCompleteDeviceList(account, deviceIds, false);
validateRegistrationIds(account, deviceIdAndRegistrationIdStream);
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (StaleDevicesException e) {
accountStaleDevices.add(new AccountStaleDevices(account.getUuid(), new StaleDevices(e.getStaleDevices())));
}
}
if (!accountMismatchedDevices.isEmpty()) {
@@ -379,6 +389,13 @@ public class MessageController {
.entity(accountMismatchedDevices)
.build();
}
if (!accountStaleDevices.isEmpty()) {
return Response
.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(accountStaleDevices)
.build();
}
List<Tag> tags = List.of(
UserAgentTagUtil.getPlatformTag(userAgent),
@@ -639,20 +656,23 @@ public class MessageController {
}
private void validateRegistrationIds(Account account, List<IncomingMessage> messages)
throws StaleDevicesException
{
List<Long> staleDevices = new LinkedList<>();
throws StaleDevicesException {
final Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream = messages
.stream()
.map(message -> new Pair<>(message.getDestinationDeviceId(), message.getDestinationRegistrationId()));
validateRegistrationIds(account, deviceIdAndRegistrationIdStream);
}
for (IncomingMessage message : messages) {
Optional<Device> device = account.getDevice(message.getDestinationDeviceId());
if (device.isPresent() &&
message.getDestinationRegistrationId() > 0 &&
message.getDestinationRegistrationId() != device.get().getRegistrationId())
{
staleDevices.add(device.get().getId());
}
}
private void validateRegistrationIds(Account account, Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream)
throws StaleDevicesException {
final List<Long> staleDevices = deviceIdAndRegistrationIdStream
.filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0)
.filter(deviceIdAndRegistrationId -> {
Optional<Device> device = account.getDevice(deviceIdAndRegistrationId.first());
return device.isPresent() && deviceIdAndRegistrationId.second() != device.get().getRegistrationId();
})
.map(Pair::first)
.collect(Collectors.toList());
if (!staleDevices.isEmpty()) {
throw new StaleDevicesException(staleDevices);