Add support for setting PNI-associated registration IDs and identity keys when changing numbers

This commit is contained in:
Jon Chambers
2022-07-26 15:19:27 -04:00
committed by GitHub
parent c252118cfc
commit dce391a248
26 changed files with 927 additions and 673 deletions

View File

@@ -68,13 +68,11 @@ import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ChangePhoneNumberRequest;
import org.whispersystems.textsecuregcm.entities.DeviceName;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.RegistrationLock;
import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnMessage;
@@ -95,7 +93,6 @@ import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.ForwardedIpUtil;
import org.whispersystems.textsecuregcm.util.Hex;
import org.whispersystems.textsecuregcm.util.ImpossiblePhoneNumberException;
import org.whispersystems.textsecuregcm.util.MessageValidation;
import org.whispersystems.textsecuregcm.util.NonNormalizedPhoneNumberException;
import org.whispersystems.textsecuregcm.util.Username;
import org.whispersystems.textsecuregcm.util.Util;
@@ -416,41 +413,9 @@ public class AccountController {
throw new ForbiddenException();
}
if (request.getDeviceSignedPrekeys() != null && !request.getDeviceSignedPrekeys().isEmpty()) {
if (request.getDeviceMessages() == null || request.getDeviceMessages().size() != request.getDeviceSignedPrekeys().size() - 1) {
// device_messages should exist and be one shorter than device_signed_prekeys, since it doesn't have the primary's key.
throw new WebApplicationException(Response.status(400).build());
}
try {
// Checks that all except master ID are in device messages
MessageValidation.validateCompleteDeviceList(
authenticatedAccount.getAccount(), request.getDeviceMessages(),
IncomingMessage::getDestinationDeviceId, true, Optional.of(Device.MASTER_ID));
MessageValidation.validateRegistrationIds(
authenticatedAccount.getAccount(), request.getDeviceMessages(),
IncomingMessage::getDestinationDeviceId, IncomingMessage::getDestinationRegistrationId);
// Checks that all including master ID are in signed prekeys
MessageValidation.validateCompleteDeviceList(
authenticatedAccount.getAccount(), request.getDeviceSignedPrekeys().entrySet(),
e -> e.getKey(), false, Optional.empty());
} catch (MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
}
} else if (request.getDeviceMessages() != null && !request.getDeviceMessages().isEmpty()) {
// device_messages shouldn't exist without device_signed_prekeys.
throw new WebApplicationException(Response.status(400).build());
}
final String number = request.number();
final String number = request.getNumber();
// Only "bill" for rate limiting if we think there's a change to be made...
if (!authenticatedAccount.getAccount().getNumber().equals(number)) {
Util.requireNormalizedNumber(number);
@@ -459,7 +424,7 @@ public class AccountController {
final Optional<StoredVerificationCode> storedVerificationCode =
pendingAccounts.getCodeForNumber(number);
if (storedVerificationCode.isEmpty() || !storedVerificationCode.get().isValid(request.getCode())) {
if (storedVerificationCode.isEmpty() || !storedVerificationCode.get().isValid(request.code())) {
throw new ForbiddenException();
}
@@ -469,24 +434,42 @@ public class AccountController {
final Optional<Account> existingAccount = accounts.getByE164(number);
if (existingAccount.isPresent()) {
verifyRegistrationLock(existingAccount.get(), request.getRegistrationLock());
verifyRegistrationLock(existingAccount.get(), request.registrationLock());
}
rateLimiters.getVerifyLimiter().clear(number);
}
final Account updatedAccount = changeNumberManager.changeNumber(
authenticatedAccount.getAccount(),
request.getNumber(),
Optional.ofNullable(request.getDeviceSignedPrekeys()).orElse(Collections.emptyMap()),
Optional.ofNullable(request.getDeviceMessages()).orElse(Collections.emptyList()));
// ...but always attempt to make the change in case a client retries and needs to re-send messages
try {
final Account updatedAccount = changeNumberManager.changeNumber(
authenticatedAccount.getAccount(),
request.number(),
request.pniIdentityKey(),
Optional.ofNullable(request.devicePniSignedPrekeys()).orElse(Collections.emptyMap()),
Optional.ofNullable(request.deviceMessages()).orElse(Collections.emptyList()),
Optional.ofNullable(request.pniRegistrationIds()).orElse(Collections.emptyMap()));
return new AccountIdentityResponse(
updatedAccount.getUuid(),
updatedAccount.getNumber(),
updatedAccount.getPhoneNumberIdentifier(),
updatedAccount.getUsername().orElse(null),
updatedAccount.isStorageSupported());
return new AccountIdentityResponse(
updatedAccount.getUuid(),
updatedAccount.getNumber(),
updatedAccount.getPhoneNumberIdentifier(),
updatedAccount.getUsername().orElse(null),
updatedAccount.isStorageSupported());
} catch (MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
} catch (IllegalArgumentException e) {
throw new BadRequestException(e);
}
}
@Timed
@@ -625,6 +608,7 @@ public class AccountController {
d.setLastSeen(Util.todayInMillis());
d.setCapabilities(attributes.getCapabilities());
d.setRegistrationId(attributes.getRegistrationId());
attributes.getPhoneNumberIdentityRegistrationId().ifPresent(d::setPhoneNumberIdentityRegistrationId);
d.setUserAgent(userAgent);
});

View File

@@ -198,6 +198,7 @@ public class DeviceController {
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
accountAttributes.getPhoneNumberIdentityRegistrationId().ifPresent(device::setPhoneNumberIdentityRegistrationId);
device.setLastSeen(Util.todayInMillis());
device.setCreated(System.currentTimeMillis());
device.setCapabilities(accountAttributes.getCapabilities());

View File

@@ -197,7 +197,11 @@ public class KeysController {
PreKey preKey = preKeysByDeviceId.get(device.getId());
if (signedPreKey != null || preKey != null) {
responseItems.add(new PreKeyResponseItem(device.getId(), device.getRegistrationId(), signedPreKey, preKey));
final int registrationId = usePhoneNumberIdentity ?
device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) :
device.getRegistrationId();
responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedPreKey, preKey));
}
}
}

View File

@@ -21,7 +21,7 @@ import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -89,8 +89,7 @@ import org.whispersystems.textsecuregcm.storage.DeletedAccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.util.MessageValidation;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
@@ -214,11 +213,23 @@ public class MessageController {
checkRateLimit(source.get(), destination.get(), userAgent);
}
MessageValidation.validateCompleteDeviceList(destination.get(), messages.getMessages(),
IncomingMessage::getDestinationDeviceId, isSyncMessage,
source.map(AuthenticatedAccount::getAuthenticatedDevice).map(Device::getId));
MessageValidation.validateRegistrationIds(destination.get(), messages.getMessages(),
IncomingMessage::getDestinationDeviceId, IncomingMessage::getDestinationRegistrationId);
final Set<Long> excludedDeviceIds;
if (isSyncMessage) {
excludedDeviceIds = Set.of(source.get().getAuthenticatedDevice().getId());
} else {
excludedDeviceIds = Collections.emptySet();
}
DestinationDeviceValidator.validateCompleteDeviceList(destination.get(),
messages.getMessages().stream().map(IncomingMessage::getDestinationDeviceId).collect(Collectors.toSet()),
excludedDeviceIds);
DestinationDeviceValidator.validateRegistrationIds(destination.get(),
messages.getMessages().stream().collect(Collectors.toMap(
IncomingMessage::getDestinationDeviceId,
IncomingMessage::getDestinationRegistrationId)),
destination.get().getPhoneNumberIdentifier().equals(destinationUuid));
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.isOnline())),
@@ -307,13 +318,25 @@ public class MessageController {
checkRateLimit(source.get(), destination.get(), userAgent);
}
final List<IncomingDeviceMessage> messagesAsList = Arrays.asList(messages);
MessageValidation.validateCompleteDeviceList(destination.get(), messagesAsList,
IncomingDeviceMessage::getDeviceId, isSyncMessage,
source.map(AuthenticatedAccount::getAuthenticatedDevice).map(Device::getId));
MessageValidation.validateRegistrationIds(destination.get(), messagesAsList,
IncomingDeviceMessage::getDeviceId,
IncomingDeviceMessage::getRegistrationId);
final Set<Long> excludedDeviceIds;
if (isSyncMessage) {
excludedDeviceIds = Set.of(source.get().getAuthenticatedDevice().getId());
} else {
excludedDeviceIds = Collections.emptySet();
}
DestinationDeviceValidator.validateCompleteDeviceList(
destination.get(),
Arrays.stream(messages).map(IncomingDeviceMessage::getDeviceId).collect(Collectors.toSet()),
excludedDeviceIds);
DestinationDeviceValidator.validateRegistrationIds(
destination.get(),
Arrays.stream(messages).collect(Collectors.toMap(
IncomingDeviceMessage::getDeviceId,
IncomingDeviceMessage::getRegistrationId)),
destination.get().getPhoneNumberIdentifier().equals(destinationUuid));
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)),
@@ -372,27 +395,29 @@ public class MessageController {
}));
checkAccessKeys(accessKeys, uuidToAccountMap);
final Map<Account, HashSet<Pair<Long, Integer>>> accountToDeviceIdAndRegistrationIdMap =
Arrays
.stream(multiRecipientMessage.getRecipients())
.collect(Collectors.toMap(
recipient -> uuidToAccountMap.get(recipient.getUuid()),
recipient -> new HashSet<>(
Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))),
(a, b) -> {
a.addAll(b);
return a;
}
));
final Map<Account, Map<Long, Integer>> accountToDeviceIdAndRegistrationIdMap = Arrays.stream(multiRecipientMessage.getRecipients())
.collect(Collectors.toMap(
recipient -> uuidToAccountMap.get(recipient.getUuid()),
recipient -> Map.of(recipient.getDeviceId(), recipient.getRegistrationId()),
(a, b) -> {
final Map<Long, Integer> combined = new HashMap<>();
combined.putAll(a);
combined.putAll(b);
return combined;
}
));
Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
uuidToAccountMap.values().forEach(account -> {
final Set<Pair<Long, Integer>> deviceIdAndRegistrationIdSet = accountToDeviceIdAndRegistrationIdMap.get(account);
final Set<Long> deviceIds = deviceIdAndRegistrationIdSet.stream().map(Pair::first).collect(Collectors.toSet());
final Set<Long> deviceIds = accountToDeviceIdAndRegistrationIdMap.get(account).keySet();
try {
MessageValidation.validateCompleteDeviceList(account, deviceIds, false, Optional.empty());
MessageValidation.validateRegistrationIds(account, deviceIdAndRegistrationIdSet.stream());
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet());
// Multi-recipient messages are always sealed-sender messages, and so can never be sent to a phone number
// identity
DestinationDeviceValidator.validateRegistrationIds(account, accountToDeviceIdAndRegistrationIdMap.get(account), false);
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));

View File

@@ -6,9 +6,11 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import javax.annotation.Nullable;
import javax.validation.constraints.Size;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.util.ExactlySize;
import java.util.OptionalInt;
public class AccountAttributes {
@@ -18,6 +20,10 @@ public class AccountAttributes {
@JsonProperty
private int registrationId;
@Nullable
@JsonProperty("pniRegistrationId")
private Integer phoneNumberIdentityRegistrationId;
@JsonProperty
@Size(max = 204, message = "This field must be less than 50 characters")
private String name;
@@ -59,6 +65,10 @@ public class AccountAttributes {
return registrationId;
}
public OptionalInt getPhoneNumberIdentityRegistrationId() {
return phoneNumberIdentityRegistrationId != null ? OptionalInt.of(phoneNumberIdentityRegistrationId) : OptionalInt.empty();
}
public String getName() {
return name;
}

View File

@@ -5,69 +5,17 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import javax.annotation.Nullable;
import javax.validation.constraints.NotBlank;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import javax.validation.constraints.NotBlank;
public class ChangePhoneNumberRequest {
@JsonProperty
@NotBlank
final String number;
@JsonProperty
@NotBlank
final String code;
@JsonProperty("reglock")
@Nullable
final String registrationLock;
@JsonProperty("device_messages")
@Nullable
final List<IncomingMessage> deviceMessages;
@JsonProperty("device_signed_prekeys")
@Nullable
final Map<Long, SignedPreKey> deviceSignedPrekeys;
@JsonCreator
public ChangePhoneNumberRequest(@JsonProperty("number") final String number,
@JsonProperty("code") final String code,
@JsonProperty("reglock") @Nullable final String registrationLock,
@JsonProperty("device_messages") @Nullable final List<IncomingMessage> deviceMessages,
@JsonProperty("device_signed_prekeys") @Nullable final Map<Long, SignedPreKey> deviceSignedPrekeys) {
this.number = number;
this.code = code;
this.registrationLock = registrationLock;
this.deviceMessages = deviceMessages;
this.deviceSignedPrekeys = deviceSignedPrekeys;
}
public String getNumber() {
return number;
}
public String getCode() {
return code;
}
@Nullable
public String getRegistrationLock() {
return registrationLock;
}
@Nullable
public List<IncomingMessage> getDeviceMessages() {
return deviceMessages;
}
@Nullable
public Map<Long, SignedPreKey> getDeviceSignedPrekeys() {
return deviceSignedPrekeys;
}
public record ChangePhoneNumberRequest(@NotBlank String number,
@NotBlank String code,
@JsonProperty("reglock") @Nullable String registrationLock,
@Nullable String pniIdentityKey,
@Nullable List<IncomingMessage> deviceMessages,
@Nullable Map<Long, SignedPreKey> devicePniSignedPrekeys,
@Nullable Map<Long, Integer> pniRegistrationIds) {
}

View File

@@ -20,6 +20,7 @@ public class OutgoingMessageEntity {
private final UUID sourceUuid;
private final int sourceDevice;
private final UUID destinationUuid;
private final UUID updatedPni;
private final byte[] content;
private final long serverTimestamp;
@@ -31,6 +32,7 @@ public class OutgoingMessageEntity {
@JsonProperty("sourceUuid") final UUID sourceUuid,
@JsonProperty("sourceDevice") final int sourceDevice,
@JsonProperty("destinationUuid") final UUID destinationUuid,
@JsonProperty("updatedPni") final UUID updatedPni,
@JsonProperty("content") final byte[] content,
@JsonProperty("serverTimestamp") final long serverTimestamp)
{
@@ -41,6 +43,7 @@ public class OutgoingMessageEntity {
this.sourceUuid = sourceUuid;
this.sourceDevice = sourceDevice;
this.destinationUuid = destinationUuid;
this.updatedPni = updatedPni;
this.content = content;
this.serverTimestamp = serverTimestamp;
}
@@ -73,6 +76,10 @@ public class OutgoingMessageEntity {
return destinationUuid;
}
public UUID getUpdatedPni() {
return updatedPni;
}
public byte[] getContent() {
return content;
}
@@ -83,23 +90,21 @@ public class OutgoingMessageEntity {
@Override
public boolean equals(final Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final OutgoingMessageEntity that = (OutgoingMessageEntity)o;
return type == that.type &&
timestamp == that.timestamp &&
sourceDevice == that.sourceDevice &&
serverTimestamp == that.serverTimestamp &&
guid.equals(that.guid) &&
Objects.equals(source, that.source) &&
Objects.equals(sourceUuid, that.sourceUuid) &&
destinationUuid.equals(that.destinationUuid) &&
Arrays.equals(content, that.content);
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
final OutgoingMessageEntity that = (OutgoingMessageEntity) o;
return type == that.type && timestamp == that.timestamp && sourceDevice == that.sourceDevice
&& serverTimestamp == that.serverTimestamp && guid.equals(that.guid) && Objects.equals(source, that.source)
&& Objects.equals(sourceUuid, that.sourceUuid) && destinationUuid.equals(that.destinationUuid)
&& Objects.equals(updatedPni, that.updatedPni) && Arrays.equals(content, that.content);
}
@Override
public int hashCode() {
int result = Objects.hash(guid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid, serverTimestamp);
int result = Objects.hash(guid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid, updatedPni,
serverTimestamp);
result = 31 * result + Arrays.hashCode(content);
return result;
}

View File

@@ -19,7 +19,9 @@ import io.micrometer.core.instrument.Tags;
import java.io.IOException;
import java.time.Clock;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
@@ -28,10 +30,13 @@ import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
@@ -39,6 +44,7 @@ import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UsernameValidator;
import org.whispersystems.textsecuregcm.util.Util;
@@ -152,6 +158,7 @@ public class AccountsManager {
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
accountAttributes.getPhoneNumberIdentityRegistrationId().ifPresent(device::setPhoneNumberIdentityRegistrationId);
device.setName(accountAttributes.getName());
device.setCapabilities(accountAttributes.getCapabilities());
device.setCreated(System.currentTimeMillis());
@@ -220,7 +227,11 @@ public class AccountsManager {
}
}
public Account changeNumber(final Account account, final String number) throws InterruptedException {
public Account changeNumber(final Account account, final String number,
@Nullable final String pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException {
final String originalNumber = account.getNumber();
final UUID originalPhoneNumberIdentifier = account.getPhoneNumberIdentifier();
@@ -228,6 +239,22 @@ public class AccountsManager {
return account;
}
if (pniSignedPreKeys != null && pniRegistrationIds != null) {
// Check that all including master ID are in signed pre-keys
DestinationDeviceValidator.validateCompleteDeviceList(
account,
pniSignedPreKeys.keySet(),
Collections.emptySet());
// Check that all devices are accounted for in the map of new PNI registration IDs
DestinationDeviceValidator.validateCompleteDeviceList(
account,
pniRegistrationIds.keySet(),
Collections.emptySet());
} else if (pniSignedPreKeys != null || pniRegistrationIds != null) {
throw new IllegalArgumentException("Signed pre-keys and registration IDs must both be null or both be non-null");
}
final AtomicReference<Account> updatedAccount = new AtomicReference<>();
deletedAccountsManager.lockAndPut(account.getNumber(), number, (originalAci, deletedAci) -> {
@@ -252,7 +279,22 @@ public class AccountsManager {
try {
numberChangedAccount = updateWithRetries(
account,
a -> true,
a -> {
//noinspection ConstantConditions
if (pniSignedPreKeys != null && pniRegistrationIds != null) {
pniSignedPreKeys.forEach((deviceId, signedPreKey) ->
a.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentitySignedPreKey(signedPreKey)));
pniRegistrationIds.forEach((deviceId, registrationId) ->
a.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentityRegistrationId(registrationId)));
}
if (pniIdentityKey != null) {
a.setPhoneNumberIdentityKey(pniIdentityKey);
}
return true;
},
a -> accounts.changeNumber(a, number, phoneNumberIdentifier),
() -> accounts.getByAccountIdentifier(uuid).orElseThrow(),
AccountChangeValidator.NUMBER_CHANGE_VALIDATOR);

View File

@@ -6,19 +6,25 @@ package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.ByteString;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.ObjectUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.AccountController;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import javax.validation.constraints.NotNull;
import java.util.List;
import java.util.Map;
import java.util.Optional;
public class ChangeNumberManager {
private static final Logger logger = LoggerFactory.getLogger(AccountController.class);
@@ -32,35 +38,54 @@ public class ChangeNumberManager {
this.accountsManager = accountsManager;
}
public Account changeNumber(
@NotNull Account account,
@NotNull final String number,
@NotNull final Map<Long, SignedPreKey> deviceSignedPrekeys,
@NotNull final List<IncomingMessage> deviceMessages) throws InterruptedException {
public Account changeNumber(final Account account, final String number,
@Nullable final String pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> deviceSignedPreKeys,
@Nullable final List<IncomingMessage> deviceMessages,
@Nullable final Map<Long, Integer> pniRegistrationIds)
throws InterruptedException, MismatchedDevicesException, StaleDevicesException {
if (ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) {
assert pniIdentityKey != null;
assert deviceSignedPreKeys != null;
assert deviceMessages != null;
assert pniRegistrationIds != null;
// Check that all except master ID are in device messages
DestinationDeviceValidator.validateCompleteDeviceList(
account,
deviceMessages.stream().map(IncomingMessage::getDestinationDeviceId).collect(Collectors.toSet()),
Set.of(Device.MASTER_ID));
DestinationDeviceValidator.validateRegistrationIds(
account,
deviceMessages.stream()
.collect(Collectors.toMap(
IncomingMessage::getDestinationDeviceId,
IncomingMessage::getDestinationRegistrationId)),
false);
} else if (!ObjectUtils.allNull(deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) {
throw new IllegalArgumentException("Signed pre-keys, device messages, and registration IDs must be all null or all non-null");
}
final Account updatedAccount;
if (number.equals(account.getNumber())) {
// This may be a request that got repeated due to poor network conditions or other client error; take no action,
// but report success since the account is in the desired state
updatedAccount = account;
} else {
updatedAccount = accountsManager.changeNumber(account, number);
updatedAccount = accountsManager.changeNumber(account, number, pniIdentityKey, deviceSignedPreKeys, pniRegistrationIds);
}
// Whether the account already has this number or not, we reset signed prekeys and resend messages.
// This makes it so the client can resend a request they didn't get a response for (timeout, etc)
// to make sure their messages sent and prekeys were updated, even if the first time around the
// server crashed at/above this point.
if (deviceSignedPrekeys != null && !deviceSignedPrekeys.isEmpty()) {
for (Map.Entry<Long, SignedPreKey> entry : deviceSignedPrekeys.entrySet()) {
accountsManager.updateDevice(updatedAccount, entry.getKey(),
d -> d.setPhoneNumberIdentitySignedPreKey(entry.getValue()));
}
for (IncomingMessage message : deviceMessages) {
sendMessageToSelf(updatedAccount, updatedAccount.getDevice(message.getDestinationDeviceId()), message);
}
// Whether the account already has this number or not, we resend messages. This makes it so the client can resend a
// request they didn't get a response for (timeout, etc) to make sure their messages sent even if the first time
// around the server crashed at/above this point.
if (deviceMessages != null) {
deviceMessages.forEach(message ->
sendMessageToSelf(updatedAccount, updatedAccount.getDevice(message.getDestinationDeviceId()), message));
}
return updatedAccount;
}
@@ -86,6 +111,7 @@ public class ChangeNumberManager {
.setSource(sourceAndDestinationAccount.getNumber())
.setSourceUuid(sourceAndDestinationAccount.getUuid().toString())
.setSourceDevice((int) Device.MASTER_ID)
.setUpdatedPni(sourceAndDestinationAccount.getPhoneNumberIdentifier().toString())
.build();
messageSender.sendMessage(sourceAndDestinationAccount, destinationDevice.get(), envelope, false);
} catch (NotPushRegisteredException e) {

View File

@@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.OptionalInt;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
@@ -49,6 +50,10 @@ public class Device {
@JsonProperty
private int registrationId;
@Nullable
@JsonProperty("pniRegistrationId")
private Integer phoneNumberIdentityRegistrationId;
@JsonProperty
private SignedPreKey signedPreKey;
@@ -184,6 +189,14 @@ public class Device {
this.registrationId = registrationId;
}
public OptionalInt getPhoneNumberIdentityRegistrationId() {
return phoneNumberIdentityRegistrationId != null ? OptionalInt.of(phoneNumberIdentityRegistrationId) : OptionalInt.empty();
}
public void setPhoneNumberIdentityRegistrationId(final int phoneNumberIdentityRegistrationId) {
this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId;
}
public SignedPreKey getSignedPreKey() {
return signedPreKey;
}

View File

@@ -389,6 +389,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null,
envelope.getSourceDevice(),
envelope.hasDestinationUuid() ? UUID.fromString(envelope.getDestinationUuid()) : null,
envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null,
envelope.hasContent() ? envelope.getContent().toByteArray() : null,
envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0);
}

View File

@@ -46,6 +46,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
private static final String KEY_SOURCE_UUID = "SU";
private static final String KEY_SOURCE_DEVICE = "SD";
private static final String KEY_DESTINATION_UUID = "DU";
private static final String KEY_UPDATED_PNI = "UP";
private static final String KEY_CONTENT = "C";
private static final String KEY_TTL = "E";
@@ -85,10 +86,12 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
.put(LOCAL_INDEX_MESSAGE_UUID_KEY_SORT, convertLocalIndexMessageUuidSortKey(messageUuid))
.put(KEY_TYPE, AttributeValues.fromInt(message.getType().getNumber()))
.put(KEY_TIMESTAMP, AttributeValues.fromLong(message.getTimestamp()))
.put(KEY_TTL, AttributeValues.fromLong(getTtlForMessage(message)));
item.put(KEY_DESTINATION_UUID, AttributeValues.fromUUID(UUID.fromString(message.getDestinationUuid())));
.put(KEY_TTL, AttributeValues.fromLong(getTtlForMessage(message)))
.put(KEY_DESTINATION_UUID, AttributeValues.fromUUID(UUID.fromString(message.getDestinationUuid())));
if (message.hasUpdatedPni()) {
item.put(KEY_UPDATED_PNI, AttributeValues.fromUUID(UUID.fromString(message.getUpdatedPni())));
}
if (message.hasSource()) {
item.put(KEY_SOURCE, AttributeValues.fromString(message.getSource()));
}
@@ -240,7 +243,9 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
final int sourceDevice = AttributeValues.getInt(message, KEY_SOURCE_DEVICE, 0);
final UUID destinationUuid = AttributeValues.getUUID(message, KEY_DESTINATION_UUID, null);
final byte[] content = AttributeValues.getByteArray(message, KEY_CONTENT, null);
return new OutgoingMessageEntity(messageUuid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid, content, sortKey.getServerTimestamp());
final UUID updatedPni = AttributeValues.getUUID(message, KEY_UPDATED_PNI, null);
return new OutgoingMessageEntity(messageUuid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid,
updatedPni, content, sortKey.getServerTimestamp());
}
private void deleteRowsMatchingQuery(AttributeValue partitionKey, QueryRequest querySpec) {

View File

@@ -0,0 +1,95 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
public class DestinationDeviceValidator {
/**
* Validates that the given device ID/registration ID pairs exactly match the corresponding device ID/registration ID
* pairs in the given destination account. This method does <em>not</em> validate that all devices associated with the
* destination account are present in the given device ID/registration ID pairs.
*
* @param account the destination account against which to check the given device ID/registration ID pairs
* @param registrationIdsByDeviceId a map of device IDs to registration IDs
* @param usePhoneNumberIdentity if {@code true}, compare provided registration IDs against device registration IDs
* associated with the account's PNI (if available); compare against the ACI-associated
* registration ID otherwise
*
* @throws StaleDevicesException if the device ID/registration ID pairs contained an entry for which the destination
* account does not have a corresponding device or if the registration IDs do not match
*/
public static void validateRegistrationIds(final Account account,
final Map<Long, Integer> registrationIdsByDeviceId,
final boolean usePhoneNumberIdentity) throws StaleDevicesException {
final List<Long> staleDevices = new ArrayList<>();
registrationIdsByDeviceId.forEach((deviceId, registrationId) -> {
if (registrationId > 0) {
final boolean registrationIdMatches =
account.getDevice(deviceId).map(device -> registrationId == (usePhoneNumberIdentity ?
device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) :
device.getRegistrationId()))
.orElse(false);
if (!registrationIdMatches) {
staleDevices.add(deviceId);
}
}
});
if (!staleDevices.isEmpty()) {
throw new StaleDevicesException(staleDevices);
}
}
/**
* Validates that the given set of device IDs from a set of messages matches the set of device IDs associated with the
* given destination account in preparation for sending those messages to the destination account. In general, the set
* of device IDs must exactly match the set of active devices associated with the destination account. When sending a
* "sync," message, though, the authenticated account is sending messages from one of their devices to all other
* devices; in that case, callers must pass the ID of the sending device in the set of {@code excludedDeviceIds}.
*
* @param account the destination account against which to check the given set of device IDs
* @param messageDeviceIds the set of device IDs to check against the destination account
* @param excludedDeviceIds a set of device IDs that may be associated with the destination account, but must not be
* present in the given set of device IDs (i.e. the device that is sending a sync message)
*
* @throws MismatchedDevicesException if the given set of device IDs contains entries not currently associated with
* the destination account or is missing entries associated with the destination
* account
*/
public static void validateCompleteDeviceList(final Account account,
final Set<Long> messageDeviceIds,
final Set<Long> excludedDeviceIds) throws MismatchedDevicesException {
final Set<Long> accountDeviceIds = account.getDevices().stream()
.filter(Device::isEnabled)
.map(Device::getId)
.filter(deviceId -> !excludedDeviceIds.contains(deviceId))
.collect(Collectors.toSet());
final Set<Long> missingDeviceIds = new HashSet<>(accountDeviceIds);
missingDeviceIds.removeAll(messageDeviceIds);
final Set<Long> extraDeviceIds = new HashSet<>(messageDeviceIds);
extraDeviceIds.removeAll(accountDeviceIds);
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
throw new MismatchedDevicesException(new ArrayList<>(missingDeviceIds), new ArrayList<>(extraDeviceIds));
}
}
}

View File

@@ -1,84 +0,0 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class MessageValidation {
public static <T> void validateRegistrationIds(Account account, List<T> messages, Function<T, Long> getDeviceId, Function<T, Integer> getRegistrationId)
throws StaleDevicesException {
final Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream = messages
.stream()
.map(message -> new Pair<>(getDeviceId.apply(message), getRegistrationId.apply(message)));
validateRegistrationIds(account, deviceIdAndRegistrationIdStream);
}
public static void validateRegistrationIds(Account account, Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream)
throws StaleDevicesException {
final List<Long> staleDevices = deviceIdAndRegistrationIdStream
.filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0)
.filter(deviceIdAndRegistrationId -> {
Optional<Device> device = account.getDevice(deviceIdAndRegistrationId.first());
return device.isPresent() && deviceIdAndRegistrationId.second() != device.get().getRegistrationId();
})
.map(Pair::first)
.collect(Collectors.toList());
if (!staleDevices.isEmpty()) {
throw new StaleDevicesException(staleDevices);
}
}
public static <T> void validateCompleteDeviceList(Account account, Collection<T> messages, Function<T, Long> getDeviceId, boolean isSyncMessage,
Optional<Long> authenticatedDeviceId)
throws MismatchedDevicesException {
Set<Long> messageDeviceIds = messages.stream().map(getDeviceId)
.collect(Collectors.toSet());
validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage, authenticatedDeviceId);
}
public static void validateCompleteDeviceList(Account account, Set<Long> messageDeviceIds, boolean isSyncMessage,
Optional<Long> authenticatedDeviceId)
throws MismatchedDevicesException {
Set<Long> accountDeviceIds = new HashSet<>();
List<Long> missingDeviceIds = new LinkedList<>();
List<Long> extraDeviceIds = new LinkedList<>();
for (Device device : account.getDevices()) {
if (device.isEnabled() &&
!(isSyncMessage && device.getId() == authenticatedDeviceId.get())) {
accountDeviceIds.add(device.getId());
if (!messageDeviceIds.contains(device.getId())) {
missingDeviceIds.add(device.getId());
}
}
}
for (Long deviceId : messageDeviceIds) {
if (!accountDeviceIds.contains(deviceId)) {
extraDeviceIds.add(deviceId);
}
}
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
throw new MismatchedDevicesException(missingDeviceIds, extraDeviceIds);
}
}
}

View File

@@ -333,8 +333,13 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
builder.setDestinationUuid(message.getDestinationUuid().toString());
if (message.getUpdatedPni() != null) {
builder.setUpdatedPni(message.getUpdatedPni().toString());
}
builder.setServerGuid(message.getGuid().toString());
final Envelope envelope = builder.build();
if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) {