Add a top-level uak to existing items

Items written before we started storing the UAK at
the top level only store the UAK in the
account blob. They will be updated on account
crawl.
This commit is contained in:
Ravi Khadiwala
2022-03-01 14:59:26 -06:00
committed by ravi-signal
parent 6283f5952d
commit 9cb098ad8a
2 changed files with 184 additions and 7 deletions

View File

@@ -8,6 +8,8 @@ import static com.codahale.metrics.MetricRegistry.name;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.io.IOException;
@@ -28,6 +30,11 @@ import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.BatchExecuteStatementRequest;
import software.amazon.awssdk.services.dynamodb.model.BatchExecuteStatementResponse;
import software.amazon.awssdk.services.dynamodb.model.BatchStatementError;
import software.amazon.awssdk.services.dynamodb.model.BatchStatementRequest;
import software.amazon.awssdk.services.dynamodb.model.BatchStatementResponse;
import software.amazon.awssdk.services.dynamodb.model.CancellationReason;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.Delete;
@@ -84,6 +91,10 @@ public class Accounts extends AbstractDynamoDbStore {
// Micrometer timers for store operations; each metric name is built from this class plus the quoted suffix.
// GET_ALL_FROM_START_TIMER is handed to scanForChunk as the timer that wraps the scan call.
private static final Timer GET_ALL_FROM_START_TIMER = Metrics.timer(name(Accounts.class, "getAllFrom"));
private static final Timer GET_ALL_FROM_OFFSET_TIMER = Metrics.timer(name(Accounts.class, "getAllFromOffset"));
private static final Timer DELETE_TIMER = Metrics.timer(name(Accounts.class, "delete"));
// Records the duration of each batched UAK update issued by normalizeIfRequired (one sample per batch).
private static final Timer NORMALIZE_ITEM_TIMER = Metrics.timer(name(Accounts.class, "normalizeItem"));
// Incremented by the number of statements in a batch for which no error was reported.
private static final Counter UAK_NORMALIZE_SUCCESS_COUNT = Metrics.counter(name(Accounts.class, "normalizeUakSuccess"));
// Incremented by the number of per-statement errors in a batch, or by the whole batch size
// when the batch call itself throws.
private static final Counter UAK_NORMALIZE_ERROR_COUNT = Metrics.counter(name(Accounts.class, "normalizeUakError"));
private static final Logger log = LoggerFactory.getLogger(Accounts.class);
@@ -627,15 +638,67 @@ public class Accounts extends AbstractDynamoDbStore {
return scanForChunk(scanRequestBuilder, maxCount, GET_ALL_FROM_START_TIMER);
}
/**
 * Converts scanned items to accounts and, as a best-effort side effect, back-fills the top-level
 * UAK attribute on records that only carry the UAK inside the serialized account.
 *
 * <p>Failures while back-filling are counted and logged but never propagated; every item is
 * still converted and returned.
 *
 * @param items raw items as returned by a scan
 * @return one account per item, in the same order as {@code items}
 */
private List<Account> normalizeIfRequired(final List<Map<String, AttributeValue>> items) {
  // Older records may lack the top-level UAK attribute even though the account itself has a UAK.
  // TODO: Can eliminate this once all uaks exist as top-level attributes
  final List<Account> parsedAccounts = new ArrayList<>();
  final List<Account> missingTopLevelUak = new ArrayList<>();

  for (final Map<String, AttributeValue> rawItem : items) {
    final Account parsed = fromItem(rawItem);
    parsedAccounts.add(parsed);

    if (rawItem.containsKey(ATTR_UAK)) {
      // already has the top-level attribute, nothing to back-fill
      continue;
    }
    if (parsed.getUnidentifiedAccessKey().isPresent()) {
      missingTopLevelUak.add(parsed);
    }
  }

  final int maxStatementsPerBatch = 25; // dynamodb max batch size
  final String partiQlUpdate = String.format("UPDATE %s SET %s = ? WHERE %s = ?", accountsTableName, ATTR_UAK, KEY_ACCOUNT_UUID);

  for (final List<Account> batch : Lists.partition(missingTopLevelUak, maxStatementsPerBatch)) {
    NORMALIZE_ITEM_TIMER.record(() -> {
      try {
        // one parameterized UPDATE per account: (uak bytes, account uuid)
        final List<BatchStatementRequest> statements = new ArrayList<>(batch.size());
        for (final Account stale : batch) {
          statements.add(BatchStatementRequest.builder()
              .statement(partiQlUpdate)
              .parameters(
                  AttributeValues.fromByteArray(stale.getUnidentifiedAccessKey().get()),
                  AttributeValues.fromUUID(stale.getUuid()))
              .build());
        }

        final BatchExecuteStatementRequest request = BatchExecuteStatementRequest.builder()
            .statements(statements)
            .build();
        final BatchExecuteStatementResponse response = client.batchExecuteStatement(request);

        // tally per-statement failures by error code
        final Map<String, Long> failuresByCode = response.responses().stream()
            .map(BatchStatementResponse::error)
            .filter(error -> error != null)
            .collect(Collectors.groupingBy(BatchStatementError::codeAsString, Collectors.counting()));

        long failed = 0;
        for (final Long perCode : failuresByCode.values()) {
          failed += perCode;
        }

        UAK_NORMALIZE_SUCCESS_COUNT.increment(batch.size() - failed);
        UAK_NORMALIZE_ERROR_COUNT.increment(failed);
        if (!failuresByCode.isEmpty()) {
          log.warn("Failed to normalize account uaks in batch of {}, error codes: {}", batch.size(), failuresByCode);
        }
      } catch (final Exception e) {
        // the whole batch is counted as failed; normalization is best-effort
        UAK_NORMALIZE_ERROR_COUNT.increment(batch.size());
        log.warn("Failed to normalize accounts in a batch of {}", batch.size(), e);
      }
    });
  }

  return parsedAccounts;
}
/**
 * Scans the accounts table for up to {@code maxCount} items and wraps the resulting accounts in
 * an {@link AccountCrawlChunk}.
 *
 * <p>Only the scan itself is recorded on {@code timer}; converting the items (and any UAK
 * normalization performed by {@code normalizeIfRequired}) happens outside the timed section.
 *
 * @param scanRequestBuilder partially configured scan request; the table name is set here
 * @param maxCount maximum item count, passed through to {@code scan}
 * @param timer timer on which the scan duration is recorded
 * @return a chunk holding the accounts and the UUID of the last account, or {@code null} as the
 *     UUID when no accounts were returned
 */
private AccountCrawlChunk scanForChunk(final ScanRequest.Builder scanRequestBuilder, final int maxCount, final Timer timer) {
  scanRequestBuilder.tableName(accountsTableName);

  // Keep the raw items (not just parsed accounts) so normalizeIfRequired can tell whether the
  // top-level UAK attribute is present on each record. The scan is issued exactly once.
  final List<Map<String, AttributeValue>> items = timer.record(() -> scan(scanRequestBuilder.build(), maxCount));
  final List<Account> accounts = normalizeIfRequired(items);

  return new AccountCrawlChunk(accounts, accounts.isEmpty() ? null : accounts.get(accounts.size() - 1).getUuid());
}