Add devices to accounts transactionally

This commit is contained in:
Jon Chambers
2023-12-07 11:19:40 -05:00
committed by GitHub
parent e084a9f2b6
commit 50d92265ea
10 changed files with 520 additions and 268 deletions

View File

@@ -68,6 +68,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Pair;
@@ -403,60 +404,63 @@ public class DeviceController {
throw new WebApplicationException(Response.status(409).build());
}
final Device device = new Device();
device.setName(accountAttributes.getName());
device.setAuthTokenHash(SaltedTokenHash.generateFor(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setPhoneNumberIdentityRegistrationId(accountAttributes.getPhoneNumberIdentityRegistrationId());
device.setLastSeen(Util.todayInMillis());
device.setCreated(System.currentTimeMillis());
device.setCapabilities(accountAttributes.getCapabilities());
return maybeDeviceActivationRequest.map(deviceActivationRequest -> {
final String signalAgent;
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
device.setSignedPreKey(deviceActivationRequest.aciSignedPreKey());
device.setPhoneNumberIdentitySignedPreKey(deviceActivationRequest.pniSignedPreKey());
if (deviceActivationRequest.apnToken().isPresent()) {
signalAgent = "OWP";
} else if (deviceActivationRequest.gcmToken().isPresent()) {
signalAgent = "OWA";
} else {
signalAgent = "OWD";
}
deviceActivationRequest.apnToken().ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});
return accounts.addDevice(account, new DeviceSpec(accountAttributes.getName(),
password,
signalAgent,
capabilities,
accountAttributes.getRegistrationId(),
accountAttributes.getPhoneNumberIdentityRegistrationId(),
accountAttributes.getFetchesMessages(),
deviceActivationRequest.apnToken(),
deviceActivationRequest.gcmToken(),
deviceActivationRequest.aciSignedPreKey(),
deviceActivationRequest.pniSignedPreKey(),
deviceActivationRequest.aciPqLastResortPreKey(),
deviceActivationRequest.pniPqLastResortPreKey()))
.thenCompose(a -> usedTokenCluster.withCluster(connection -> connection.async()
.set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)))
.thenApply(ignored -> a))
.join();
})
.orElseGet(() -> {
final Device device = new Device();
device.setName(accountAttributes.getName());
device.setAuthTokenHash(SaltedTokenHash.generateFor(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setPhoneNumberIdentityRegistrationId(accountAttributes.getPhoneNumberIdentityRegistrationId());
device.setLastSeen(Util.todayInMillis());
device.setCreated(System.currentTimeMillis());
device.setCapabilities(accountAttributes.getCapabilities());
deviceActivationRequest.gcmToken().ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
});
final Account updatedAccount = accounts.update(account, a -> {
device.setId(a.getNextDeviceId());
final Account updatedAccount = accounts.update(account, a -> {
device.setId(a.getNextDeviceId());
CompletableFuture.allOf(
keys.delete(a.getUuid(), device.getId()),
keys.delete(a.getPhoneNumberIdentifier(), device.getId()),
messages.clear(a.getUuid(), device.getId()))
.join();
final CompletableFuture<Void> deleteKeysFuture = CompletableFuture.allOf(
keys.delete(a.getUuid(), device.getId()),
keys.delete(a.getPhoneNumberIdentifier(), device.getId()));
a.addDevice(device);
});
messages.clear(a.getUuid(), device.getId()).join();
usedTokenCluster.useCluster(connection ->
connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)));
deleteKeysFuture.join();
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> CompletableFuture.allOf(
keys.storeEcSignedPreKeys(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey())),
keys.storePqLastResort(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey())),
keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey())),
keys.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey())))
.join());
a.addDevice(device);
});
if (maybeAciFromToken.isPresent()) {
usedTokenCluster.useCluster(connection ->
connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)));
}
return new Pair<>(updatedAccount, device);
return new Pair<>(updatedAccount, device);
});
}
private static String getUsedTokenKey(final String token) {

View File

@@ -43,6 +43,7 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
@@ -140,18 +141,24 @@ public class RegistrationController {
}
final Account account = accounts.create(number,
password,
signalAgent,
registrationRequest.accountAttributes(),
existingAccount.map(Account::getBadges).orElseGet(ArrayList::new),
registrationRequest.aciIdentityKey(),
registrationRequest.pniIdentityKey(),
registrationRequest.deviceActivationRequest().aciSignedPreKey(),
registrationRequest.deviceActivationRequest().pniSignedPreKey(),
registrationRequest.deviceActivationRequest().aciPqLastResortPreKey(),
registrationRequest.deviceActivationRequest().pniPqLastResortPreKey(),
registrationRequest.deviceActivationRequest().apnToken(),
registrationRequest.deviceActivationRequest().gcmToken());
new DeviceSpec(
registrationRequest.accountAttributes().getName(),
password,
signalAgent,
registrationRequest.accountAttributes().getCapabilities(),
registrationRequest.accountAttributes().getRegistrationId(),
registrationRequest.accountAttributes().getPhoneNumberIdentityRegistrationId(),
registrationRequest.accountAttributes().getFetchesMessages(),
registrationRequest.deviceActivationRequest().apnToken(),
registrationRequest.deviceActivationRequest().gcmToken(),
registrationRequest.deviceActivationRequest().aciSignedPreKey(),
registrationRequest.deviceActivationRequest().pniSignedPreKey(),
registrationRequest.deviceActivationRequest().aciPqLastResortPreKey(),
registrationRequest.deviceActivationRequest().pniPqLastResortPreKey()));
Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)),

View File

@@ -53,9 +53,7 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
@@ -68,6 +66,7 @@ import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.ParallelFlux;
@@ -132,11 +131,6 @@ public class AccountsManager {
private static final int MAX_UPDATE_ATTEMPTS = 10;
@FunctionalInterface
private interface AccountPersister {
void persistAccount(Account account) throws UsernameHashNotAvailableException;
}
public enum DeletionReason {
ADMIN_DELETED("admin"),
EXPIRED ("expired"),
@@ -181,46 +175,18 @@ public class AccountsManager {
this.clock = requireNonNull(clock);
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public Account create(final String number,
final String password,
final String signalAgent,
final AccountAttributes accountAttributes,
final List<AccountBadge> accountBadges,
final IdentityKey aciIdentityKey,
final IdentityKey pniIdentityKey,
final ECSignedPreKey aciSignedPreKey,
final ECSignedPreKey pniSignedPreKey,
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey,
final Optional<ApnRegistrationId> maybeApnRegistrationId,
final Optional<GcmRegistrationId> maybeGcmRegistrationId) throws InterruptedException {
final DeviceSpec primaryDeviceSpec) throws InterruptedException {
try (Timer.Context ignored = createTimer.time()) {
final Account account = new Account();
accountLockManager.withLock(List.of(number), () -> {
final Device device = new Device();
device.setId(Device.PRIMARY_ID);
device.setAuthTokenHash(SaltedTokenHash.generateFor(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setPhoneNumberIdentityRegistrationId(accountAttributes.getPhoneNumberIdentityRegistrationId());
device.setName(accountAttributes.getName());
device.setCapabilities(accountAttributes.getCapabilities());
device.setCreated(System.currentTimeMillis());
device.setLastSeen(Util.todayInMillis());
device.setUserAgent(signalAgent);
device.setSignedPreKey(aciSignedPreKey);
device.setPhoneNumberIdentitySignedPreKey(pniSignedPreKey);
maybeApnRegistrationId.ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});
maybeGcmRegistrationId.ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
final Device device = primaryDeviceSpec.toDevice(Device.PRIMARY_ID, clock);
account.setNumber(number, phoneNumberIdentifiers.getPhoneNumberIdentifier(number));
@@ -245,10 +211,10 @@ public class AccountsManager {
a -> keysManager.buildWriteItemsForRepeatedUseKeys(a.getIdentifier(IdentityType.ACI),
a.getIdentifier(IdentityType.PNI),
Device.PRIMARY_ID,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey),
primaryDeviceSpec.aciSignedPreKey(),
primaryDeviceSpec.pniSignedPreKey(),
primaryDeviceSpec.aciPqLastResortPreKey(),
primaryDeviceSpec.pniPqLastResortPreKey()),
(aci, pni) -> CompletableFuture.allOf(
keysManager.delete(aci),
keysManager.delete(pni),
@@ -299,6 +265,42 @@ public class AccountsManager {
}
}
public CompletableFuture<Pair<Account, Device>> addDevice(final Account account, final DeviceSpec deviceSpec) {
return addDevice(account.getIdentifier(IdentityType.ACI), deviceSpec, MAX_UPDATE_ATTEMPTS);
}
private CompletableFuture<Pair<Account, Device>> addDevice(final UUID accountIdentifier, final DeviceSpec deviceSpec, final int retries) {
return accounts.getByAccountIdentifierAsync(accountIdentifier)
.thenApply(maybeAccount -> maybeAccount.orElseThrow(ContestedOptimisticLockException::new))
.thenCompose(account -> {
final byte nextDeviceId = account.getNextDeviceId();
account.addDevice(deviceSpec.toDevice(nextDeviceId, clock));
final List<TransactWriteItem> additionalWriteItems = keysManager.buildWriteItemsForRepeatedUseKeys(
account.getIdentifier(IdentityType.ACI),
account.getIdentifier(IdentityType.PNI),
nextDeviceId,
deviceSpec.aciSignedPreKey(),
deviceSpec.pniSignedPreKey(),
deviceSpec.aciPqLastResortPreKey(),
deviceSpec.pniPqLastResortPreKey());
return CompletableFuture.allOf(
keysManager.delete(account.getUuid(), nextDeviceId),
keysManager.delete(account.getPhoneNumberIdentifier(), nextDeviceId),
messagesManager.clear(account.getUuid(), nextDeviceId))
.thenCompose(ignored -> accounts.updateTransactionallyAsync(account, additionalWriteItems))
.thenApply(ignored -> new Pair<>(account, account.getDevice(nextDeviceId).orElseThrow()));
})
.exceptionallyCompose(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof ContestedOptimisticLockException && retries > 0) {
return addDevice(accountIdentifier, deviceSpec, retries - 1);
}
return CompletableFuture.failedFuture(throwable);
});
}
public CompletableFuture<Account> removeDevice(final Account account, final byte deviceId) {
if (deviceId == Device.PRIMARY_ID) {
throw new IllegalArgumentException("Cannot remove primary device");
@@ -705,19 +707,6 @@ public class AccountsManager {
final Consumer<Account> persister,
final Supplier<Account> retriever,
final AccountChangeValidator changeValidator) {
try {
return failableUpdateWithRetries(account, updater, persister::accept, retriever, changeValidator);
} catch (UsernameHashNotAvailableException e) {
// not possible
throw new IllegalStateException(e);
}
}
private Account failableUpdateWithRetries(Account account,
final Function<Account, Boolean> updater,
final AccountPersister persister,
final Supplier<Account> retriever,
final AccountChangeValidator changeValidator) throws UsernameHashNotAvailableException {
Account originalAccount = AccountUtil.cloneAccountAsNotStale(account);
@@ -731,7 +720,7 @@ public class AccountsManager {
while (tries < maxTries) {
try {
persister.persistAccount(account);
persister.accept(account);
final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account);
account.markStale();

View File

@@ -0,0 +1,90 @@
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.util.Util;
import java.time.Clock;
import java.util.Arrays;
import java.util.Objects;
import java.util.Optional;
public record DeviceSpec(
byte[] deviceNameCiphertext,
String password,
String signalAgent,
Device.DeviceCapabilities capabilities,
int aciRegistrationId,
int pniRegistrationId,
boolean fetchesMessages,
Optional<ApnRegistrationId> apnRegistrationId,
Optional<GcmRegistrationId> gcmRegistrationId,
ECSignedPreKey aciSignedPreKey,
ECSignedPreKey pniSignedPreKey,
KEMSignedPreKey aciPqLastResortPreKey,
KEMSignedPreKey pniPqLastResortPreKey) {
public Device toDevice(final byte deviceId, final Clock clock) {
final Device device = new Device();
device.setId(deviceId);
device.setAuthTokenHash(SaltedTokenHash.generateFor(password()));
device.setFetchesMessages(fetchesMessages());
device.setRegistrationId(aciRegistrationId());
device.setPhoneNumberIdentityRegistrationId(pniRegistrationId());
device.setName(deviceNameCiphertext());
device.setCapabilities(capabilities());
device.setCreated(clock.millis());
device.setLastSeen(Util.todayInMillis());
device.setUserAgent(signalAgent());
device.setSignedPreKey(aciSignedPreKey());
device.setPhoneNumberIdentitySignedPreKey(pniSignedPreKey());
apnRegistrationId().ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});
gcmRegistrationId().ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
return device;
}
@Override
public boolean equals(final Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
final DeviceSpec that = (DeviceSpec) o;
return aciRegistrationId == that.aciRegistrationId
&& pniRegistrationId == that.pniRegistrationId
&& fetchesMessages == that.fetchesMessages
&& Arrays.equals(deviceNameCiphertext, that.deviceNameCiphertext)
&& Objects.equals(password, that.password)
&& Objects.equals(signalAgent, that.signalAgent)
&& Objects.equals(capabilities, that.capabilities)
&& Objects.equals(apnRegistrationId, that.apnRegistrationId)
&& Objects.equals(gcmRegistrationId, that.gcmRegistrationId)
&& Objects.equals(aciSignedPreKey, that.aciSignedPreKey)
&& Objects.equals(pniSignedPreKey, that.pniSignedPreKey)
&& Objects.equals(aciPqLastResortPreKey, that.aciPqLastResortPreKey)
&& Objects.equals(pniPqLastResortPreKey, that.pniPqLastResortPreKey);
}
@Override
public int hashCode() {
int result = Objects.hash(password, signalAgent, capabilities, aciRegistrationId, pniRegistrationId,
fetchesMessages, apnRegistrationId, gcmRegistrationId, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey,
pniPqLastResortPreKey);
result = 31 * result + Arrays.hashCode(deviceNameCiphertext);
return result;
}
}