Update signed pre-keys in transactions

This commit is contained in:
Jon Chambers
2023-12-05 14:20:16 -05:00
committed by GitHub
parent ede9297139
commit df421e0182
8 changed files with 415 additions and 252 deletions

View File

@@ -14,9 +14,11 @@ import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
@@ -67,6 +69,8 @@ public class KeysController {
private final AccountsManager accounts;
private final Experiment compareSignedEcPreKeysExperiment = new Experiment("compareSignedEcPreKeys");
private static final CompletableFuture<?>[] EMPTY_FUTURE_ARRAY = new CompletableFuture[0];
public KeysController(RateLimiters rateLimiters, KeysManager keys, AccountsManager accounts) {
this.rateLimiters = rateLimiters;
this.keys = keys;
@@ -110,24 +114,51 @@ public class KeysController {
description="whether this operation applies to the account (aci) or phone-number (pni) identity")
@QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) {
Account account = disabledPermittedAuth.getAccount();
final Account account = disabledPermittedAuth.getAccount();
final Device device = disabledPermittedAuth.getAuthenticatedDevice();
final UUID identifier = account.getIdentifier(identityType);
checkSignedPreKeySignatures(setKeysRequest, account.getIdentityKey(identityType));
final CompletableFuture<Account> updateAccountFuture;
if (setKeysRequest.signedPreKey() != null &&
!setKeysRequest.signedPreKey().equals(device.getSignedPreKey(identityType))) {
account = accounts.update(account, a -> a.getDevice(device.getId()).ifPresent(d -> {
switch (identityType) {
case ACI -> d.setSignedPreKey(setKeysRequest.signedPreKey());
case PNI -> d.setPhoneNumberIdentitySignedPreKey(setKeysRequest.signedPreKey());
}
}));
updateAccountFuture = accounts.updateDeviceTransactionallyAsync(account,
device.getId(),
d -> {
switch (identityType) {
case ACI -> d.setSignedPreKey(setKeysRequest.signedPreKey());
case PNI -> d.setPhoneNumberIdentitySignedPreKey(setKeysRequest.signedPreKey());
}
},
d -> keys.buildWriteItemForEcSignedPreKey(identifier, d.getId(), setKeysRequest.signedPreKey())
.map(List::of)
.orElseGet(Collections::emptyList))
.toCompletableFuture();
} else {
updateAccountFuture = CompletableFuture.completedFuture(account);
}
return keys.store(account.getIdentifier(identityType), device.getId(),
setKeysRequest.preKeys(), setKeysRequest.pqPreKeys(), setKeysRequest.signedPreKey(), setKeysRequest.pqLastResortPreKey())
return updateAccountFuture.thenCompose(updatedAccount -> {
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(3);
if (setKeysRequest.preKeys() != null) {
storeFutures.add(keys.storeEcOneTimePreKeys(identifier, device.getId(), setKeysRequest.preKeys()));
}
if (setKeysRequest.pqPreKeys() != null) {
storeFutures.add(keys.storeKemOneTimePreKeys(identifier, device.getId(), setKeysRequest.pqPreKeys()));
}
if (setKeysRequest.pqLastResortPreKey() != null) {
storeFutures.add(
keys.storePqLastResort(identifier, Map.of(device.getId(), setKeysRequest.pqLastResortPreKey())));
}
return CompletableFuture.allOf(storeFutures.toArray(EMPTY_FUTURE_ARRAY));
})
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
}
@@ -265,17 +296,21 @@ public class KeysController {
@Valid final ECSignedPreKey signedPreKey,
@QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) {
Device device = auth.getAuthenticatedDevice();
final UUID identifier = auth.getAccount().getIdentifier(identityType);
final byte deviceId = auth.getAuthenticatedDevice().getId();
accounts.updateDevice(auth.getAccount(), device.getId(), d -> {
switch (identityType) {
case ACI -> d.setSignedPreKey(signedPreKey);
case PNI -> d.setPhoneNumberIdentitySignedPreKey(signedPreKey);
}
});
return keys.storeEcSignedPreKeys(auth.getAccount().getIdentifier(identityType),
Map.of(device.getId(), signedPreKey))
return accounts.updateDeviceTransactionallyAsync(auth.getAccount(),
deviceId,
d -> {
switch (identityType) {
case ACI -> d.setSignedPreKey(signedPreKey);
case PNI -> d.setPhoneNumberIdentitySignedPreKey(signedPreKey);
}
},
d -> keys.buildWriteItemForEcSignedPreKey(identifier, d.getId(), signedPreKey)
.map(List::of)
.orElseGet(Collections::emptyList))
.toCompletableFuture()
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
}

View File

@@ -95,6 +95,7 @@ public class Accounts extends AbstractDynamoDbStore {
private static final Timer RESERVE_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "reserveUsername"));
private static final Timer CLEAR_USERNAME_HASH_TIMER = Metrics.timer(name(Accounts.class, "clearUsernameHash"));
private static final Timer UPDATE_TIMER = Metrics.timer(name(Accounts.class, "update"));
private static final Timer UPDATE_TRANSACTIONALLY_TIMER = Metrics.timer(name(Accounts.class, "updateTransactionally"));
private static final Timer RECLAIM_TIMER = Metrics.timer(name(Accounts.class, "reclaim"));
private static final Timer GET_BY_NUMBER_TIMER = Metrics.timer(name(Accounts.class, "getByNumber"));
private static final Timer GET_BY_USERNAME_HASH_TIMER = Metrics.timer(name(Accounts.class, "getByUsernameHash"));
@@ -277,6 +278,7 @@ public class Accounts extends AbstractDynamoDbStore {
!existingAccount.getNumber().equals(accountToCreate.getNumber())) {
throw new IllegalArgumentException("reclaimed accounts must match");
}
return AsyncTimerUtil.record(RECLAIM_TIMER, () -> {
accountToCreate.setVersion(existingAccount.getVersion());
@@ -364,7 +366,8 @@ public class Accounts extends AbstractDynamoDbStore {
public void changeNumber(final Account account,
final String number,
final UUID phoneNumberIdentifier,
final Optional<UUID> maybeDisplacedAccountIdentifier) {
final Optional<UUID> maybeDisplacedAccountIdentifier,
final Collection<TransactWriteItem> additionalWriteItems) {
CHANGE_NUMBER_TIMER.record(() -> {
final String originalNumber = account.getNumber();
@@ -413,6 +416,8 @@ public class Accounts extends AbstractDynamoDbStore {
.build())
.build());
writeItems.addAll(additionalWriteItems);
final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder()
.transactItems(writeItems)
.build();
@@ -863,6 +868,35 @@ public class Accounts extends AbstractDynamoDbStore {
joinAndUnwrapUpdateFuture(updateAsync(account));
}
public CompletionStage<Void> updateTransactionallyAsync(final Account account,
final Collection<TransactWriteItem> additionalWriteItems) {
return AsyncTimerUtil.record(UPDATE_TRANSACTIONALLY_TIMER, () -> {
final List<TransactWriteItem> writeItems = new ArrayList<>(additionalWriteItems.size() + 1);
writeItems.add(UpdateAccountSpec.forAccount(accountsTableName, account).transactItem());
writeItems.addAll(additionalWriteItems);
return asyncClient.transactWriteItems(TransactWriteItemsRequest.builder()
.transactItems(writeItems)
.build())
.thenApply(response -> {
account.setVersion(account.getVersion() + 1);
return (Void) null;
})
.exceptionally(throwable -> {
final Throwable unwrapped = ExceptionUtils.unwrap(throwable);
if (unwrapped instanceof TransactionCanceledException transactionCanceledException) {
if ("ConditionalCheckFailed".equals(transactionCanceledException.cancellationReasons().get(0).code())) {
throw new ContestedOptimisticLockException();
}
}
throw CompletableFutureUtils.errorAsCompletionException(throwable);
});
});
}
public CompletableFuture<Boolean> usernameHashAvailable(final byte[] username) {
return usernameHashAvailable(Optional.empty(), username);
}

View File

@@ -27,6 +27,7 @@ import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -39,10 +40,10 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
@@ -71,6 +72,7 @@ import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.ParallelFlux;
import reactor.core.scheduler.Scheduler;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
public class AccountsManager {
@@ -365,39 +367,25 @@ public class AccountsManager {
final UUID uuid = account.getUuid();
final UUID phoneNumberIdentifier = phoneNumberIdentifiers.getPhoneNumberIdentifier(targetNumber);
final Account numberChangedAccount;
numberChangedAccount = updateWithRetries(
account,
a -> {
setPniKeys(account, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds);
return true;
},
a -> accounts.changeNumber(a, targetNumber, phoneNumberIdentifier, maybeDisplacedUuid),
() -> accounts.getByAccountIdentifier(uuid).orElseThrow(),
AccountChangeValidator.NUMBER_CHANGE_VALIDATOR);
updatedAccount.set(numberChangedAccount);
CompletableFuture.allOf(
keysManager.delete(phoneNumberIdentifier),
keysManager.delete(originalPhoneNumberIdentifier))
.join();
keysManager.storeEcSignedPreKeys(phoneNumberIdentifier, pniSignedPreKeys);
final Collection<TransactWriteItem> keyWriteItems =
buildKeyWriteItems(uuid, phoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys);
if (pniPqLastResortPreKeys != null) {
keysManager.getPqEnabledDevices(uuid).thenCompose(
deviceIds -> keysManager.storePqLastResort(
phoneNumberIdentifier,
deviceIds.stream()
.filter(pniPqLastResortPreKeys::containsKey)
.collect(
Collectors.toMap(
Function.identity(),
pniPqLastResortPreKeys::get))))
.join();
}
final Account numberChangedAccount = updateWithRetries(
account,
a -> {
setPniKeys(account, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds);
return true;
},
a -> accounts.changeNumber(a, targetNumber, phoneNumberIdentifier, maybeDisplacedUuid, keyWriteItems),
() -> accounts.getByAccountIdentifier(uuid).orElseThrow(),
AccountChangeValidator.NUMBER_CHANGE_VALIDATOR);
updatedAccount.set(numberChangedAccount);
});
return updatedAccount.get();
@@ -410,31 +398,58 @@ public class AccountsManager {
final Map<Byte, Integer> pniRegistrationIds) throws MismatchedDevicesException {
validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds);
final UUID pni = account.getPhoneNumberIdentifier();
final Account updatedAccount = update(account, a -> { return setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); });
final UUID aci = account.getIdentifier(IdentityType.ACI);
final UUID pni = account.getIdentifier(IdentityType.PNI);
final Collection<TransactWriteItem> keyWriteItems =
buildKeyWriteItems(pni, pni, pniSignedPreKeys, pniPqLastResortPreKeys);
return redisDeleteAsync(account)
.thenCompose(ignored -> keysManager.delete(pni))
.thenCompose(ignored -> updateTransactionallyWithRetriesAsync(account,
a -> setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds),
accounts::updateTransactionallyAsync,
() -> accounts.getByAccountIdentifierAsync(aci).thenApply(Optional::orElseThrow),
a -> keyWriteItems,
AccountChangeValidator.GENERAL_CHANGE_VALIDATOR,
MAX_UPDATE_ATTEMPTS))
.join();
}
private Collection<TransactWriteItem> buildKeyWriteItems(
final UUID enabledDevicesIdentifier,
final UUID phoneNumberIdentifier,
@Nullable final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys) {
final List<TransactWriteItem> keyWriteItems = new ArrayList<>();
if (pniSignedPreKeys != null) {
pniSignedPreKeys.forEach((deviceId, signedPreKey) ->
keysManager.buildWriteItemForEcSignedPreKey(phoneNumberIdentifier, deviceId, signedPreKey)
.ifPresent(keyWriteItems::add));
}
keysManager.delete(pni);
keysManager.storeEcSignedPreKeys(pni, pniSignedPreKeys).join();
if (pniPqLastResortPreKeys != null) {
keysManager.getPqEnabledDevices(pni)
.thenCompose(
deviceIds -> keysManager.storePqLastResort(
pni,
deviceIds.stream()
.filter(pniPqLastResortPreKeys::containsKey)
.collect(Collectors.toMap(Function.identity(), pniPqLastResortPreKeys::get))))
keysManager.getPqEnabledDevices(enabledDevicesIdentifier)
.thenAccept(deviceIds -> deviceIds.stream()
.filter(pniPqLastResortPreKeys::containsKey)
.map(deviceId -> keysManager.buildWriteItemForLastResortKey(phoneNumberIdentifier,
deviceId,
pniPqLastResortPreKeys.get(deviceId)))
.forEach(keyWriteItems::add))
.join();
}
return updatedAccount;
return keyWriteItems;
}
private boolean setPniKeys(final Account account,
private void setPniKeys(final Account account,
@Nullable final IdentityKey pniIdentityKey,
@Nullable final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Byte, Integer> pniRegistrationIds) {
if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) {
return false;
return;
} else if (!ObjectUtils.allNotNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) {
throw new IllegalArgumentException("PNI identity key, signed pre-keys, and registration IDs must be all null or all non-null");
}
@@ -455,8 +470,6 @@ public class AccountsManager {
}
account.setPhoneNumberIdentityKey(pniIdentityKey);
return changed;
}
private void validateDevices(final Account account,
@@ -777,6 +790,42 @@ public class AccountsManager {
return CompletableFuture.failedFuture(new OptimisticLockRetryLimitExceededException());
}
private CompletionStage<Account> updateTransactionallyWithRetriesAsync(final Account account,
final Consumer<Account> updater,
final BiFunction<Account, Collection<TransactWriteItem>, CompletionStage<Void>> persister,
final Supplier<CompletionStage<Account>> retriever,
final Function<Account, Collection<TransactWriteItem>> additionalWriteItemProvider,
final AccountChangeValidator changeValidator,
final int remainingTries) {
final Account originalAccount = AccountUtil.cloneAccountAsNotStale(account);
final Collection<TransactWriteItem> additionalWriteItems = additionalWriteItemProvider.apply(account);
updater.accept(account);
if (remainingTries > 0) {
return persister.apply(account, additionalWriteItems)
.thenApply(ignored -> {
final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account);
account.markStale();
changeValidator.validateChange(originalAccount, updatedAccount);
return updatedAccount;
})
.exceptionallyCompose(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof ContestedOptimisticLockException) {
return retriever.get().thenCompose(refreshedAccount ->
updateTransactionallyWithRetriesAsync(refreshedAccount, updater, persister, retriever, additionalWriteItemProvider, changeValidator, remainingTries - 1));
} else {
throw ExceptionUtils.wrap(throwable);
}
});
}
return CompletableFuture.failedFuture(new OptimisticLockRetryLimitExceededException());
}
public Account updateDevice(Account account, byte deviceId, Consumer<Device> deviceUpdater) {
return update(account, a -> {
a.getDevice(deviceId).ifPresent(deviceUpdater);
@@ -794,6 +843,22 @@ public class AccountsManager {
});
}
public CompletionStage<Account> updateDeviceTransactionallyAsync(final Account account,
final byte deviceId,
final Consumer<Device> deviceUpdater,
final Function<Device, Collection<TransactWriteItem>> additionalWriteItemProvider) {
final UUID uuid = account.getUuid();
return redisDeleteAsync(account).thenCompose(ignored -> updateTransactionallyWithRetriesAsync(account,
a -> a.getDevice(deviceId).ifPresent(deviceUpdater),
accounts::updateTransactionallyAsync,
() -> accounts.getByAccountIdentifierAsync(uuid).thenApply(Optional::orElseThrow),
a -> additionalWriteItemProvider.apply(a.getDevice(deviceId).orElseThrow()),
AccountChangeValidator.GENERAL_CHANGE_VALIDATOR,
MAX_UPDATE_ATTEMPTS));
}
public Optional<Account> getByE164(final String number) {
return checkRedisThenAccounts(
getByNumberTimer,

View File

@@ -6,13 +6,11 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
@@ -43,34 +41,20 @@ public class KeysManager {
this.dynamicConfigurationManager = dynamicConfigurationManager;
}
public CompletableFuture<Void> store(
final UUID identifier, final byte deviceId,
@Nullable final List<ECPreKey> ecKeys,
@Nullable final List<KEMSignedPreKey> pqKeys,
@Nullable final ECSignedPreKey ecSignedPreKey,
@Nullable final KEMSignedPreKey pqLastResortKey) {
public Optional<TransactWriteItem> buildWriteItemForEcSignedPreKey(final UUID identifier,
final byte deviceId,
final ECSignedPreKey ecSignedPreKey) {
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>();
return dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()
? Optional.of(ecSignedPreKeys.buildTransactWriteItem(identifier, deviceId, ecSignedPreKey))
: Optional.empty();
}
if (ecKeys != null && !ecKeys.isEmpty()) {
storeFutures.add(ecPreKeys.store(identifier, deviceId, ecKeys));
}
public TransactWriteItem buildWriteItemForLastResortKey(final UUID identifier,
final byte deviceId,
final KEMSignedPreKey lastResortSignedPreKey) {
if (pqKeys != null && !pqKeys.isEmpty()) {
storeFutures.add(pqPreKeys.store(identifier, deviceId, pqKeys));
}
if (ecSignedPreKey != null
&& dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()) {
storeFutures.add(ecSignedPreKeys.store(identifier, deviceId, ecSignedPreKey));
}
if (pqLastResortKey != null) {
storeFutures.add(pqLastResortKeys.store(identifier, deviceId, pqLastResortKey));
}
return CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0]));
return pqLastResortKeys.buildTransactWriteItem(identifier, deviceId, lastResortSignedPreKey);
}
public List<TransactWriteItem> buildWriteItemsForRepeatedUseKeys(final UUID accountIdentifier,