Create separate key stores for different kinds of pre-keys

This commit is contained in:
Jon Chambers
2023-06-06 17:08:26 -04:00
committed by GitHub
parent cac04146de
commit 2b08742c0a
34 changed files with 1482 additions and 847 deletions

View File

@@ -176,7 +176,7 @@ import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
import org.whispersystems.textsecuregcm.storage.DeletedAccounts;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.IssuedReceiptsManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagePersister;
import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
@@ -345,10 +345,11 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getProfiles().getTableName());
Keys keys = new Keys(dynamoDbClient,
KeysManager keys = new KeysManager(
dynamoDbAsyncClient,
config.getDynamoDbTables().getEcKeys().getTableName(),
config.getDynamoDbTables().getPqKeys().getTableName(),
config.getDynamoDbTables().getPqLastResortKeys().getTableName());
config.getDynamoDbTables().getKemKeys().getTableName(),
config.getDynamoDbTables().getKemLastResortKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getMessages().getTableName(),
config.getDynamoDbTables().getMessages().getExpiration(),

View File

@@ -51,8 +51,8 @@ public class DynamoDbTables {
private final Table deletedAccountsLock;
private final IssuedReceiptsTableConfiguration issuedReceipts;
private final Table ecKeys;
private final Table pqKeys;
private final Table pqLastResortKeys;
private final Table kemKeys;
private final Table kemLastResortKeys;
private final TableWithExpiration messages;
private final Table pendingAccounts;
private final Table pendingDevices;
@@ -72,8 +72,8 @@ public class DynamoDbTables {
@JsonProperty("deletedAccountsLock") final Table deletedAccountsLock,
@JsonProperty("issuedReceipts") final IssuedReceiptsTableConfiguration issuedReceipts,
@JsonProperty("ecKeys") final Table ecKeys,
@JsonProperty("pqKeys") final Table pqKeys,
@JsonProperty("pqLastResortKeys") final Table pqLastResortKeys,
@JsonProperty("pqKeys") final Table kemKeys,
@JsonProperty("pqLastResortKeys") final Table kemLastResortKeys,
@JsonProperty("messages") final TableWithExpiration messages,
@JsonProperty("pendingAccounts") final Table pendingAccounts,
@JsonProperty("pendingDevices") final Table pendingDevices,
@@ -92,8 +92,8 @@ public class DynamoDbTables {
this.deletedAccountsLock = deletedAccountsLock;
this.issuedReceipts = issuedReceipts;
this.ecKeys = ecKeys;
this.pqKeys = pqKeys;
this.pqLastResortKeys = pqLastResortKeys;
this.kemKeys = kemKeys;
this.kemLastResortKeys = kemLastResortKeys;
this.messages = messages;
this.pendingAccounts = pendingAccounts;
this.pendingDevices = pendingDevices;
@@ -140,14 +140,14 @@ public class DynamoDbTables {
@NotNull
@Valid
public Table getPqKeys() {
return pqKeys;
public Table getKemKeys() {
return kemKeys;
}
@NotNull
@Valid
public Table getPqLastResortKeys() {
return pqLastResortKeys;
public Table getKemLastResortKeys() {
return kemLastResortKeys;
}
@NotNull

View File

@@ -51,7 +51,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.util.Pair;
@@ -67,14 +67,14 @@ public class DeviceController {
private final StoredVerificationCodeManager pendingDevices;
private final AccountsManager accounts;
private final MessagesManager messages;
private final Keys keys;
private final KeysManager keys;
private final RateLimiters rateLimiters;
private final Map<String, Integer> maxDeviceConfiguration;
public DeviceController(StoredVerificationCodeManager pendingDevices,
AccountsManager accounts,
MessagesManager messages,
Keys keys,
KeysManager keys,
RateLimiters rateLimiters,
Map<String, Integer> maxDeviceConfiguration) {
this.pendingDevices = pendingDevices;

View File

@@ -53,7 +53,7 @@ import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v2/keys")
@@ -61,7 +61,7 @@ import org.whispersystems.textsecuregcm.storage.Keys;
public class KeysController {
private final RateLimiters rateLimiters;
private final Keys keys;
private final KeysManager keys;
private final AccountsManager accounts;
private static final String IDENTITY_KEY_CHANGE_COUNTER_NAME = name(KeysController.class, "identityKeyChange");
@@ -70,7 +70,7 @@ public class KeysController {
private static final String IDENTITY_TYPE_TAG_NAME = "identityType";
private static final String HAS_IDENTITY_KEY_TAG_NAME = "hasIdentityKey";
public KeysController(RateLimiters rateLimiters, Keys keys, AccountsManager accounts) {
public KeysController(RateLimiters rateLimiters, KeysManager keys, AccountsManager accounts) {
this.rateLimiters = rateLimiters;
this.keys = keys;
this.accounts = accounts;

View File

@@ -48,7 +48,7 @@ import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
@@ -74,18 +74,18 @@ public class RegistrationController {
private final AccountsManager accounts;
private final PhoneVerificationTokenManager phoneVerificationTokenManager;
private final RegistrationLockVerificationManager registrationLockVerificationManager;
private final Keys keys;
private final KeysManager keysManager;
private final RateLimiters rateLimiters;
public RegistrationController(final AccountsManager accounts,
final PhoneVerificationTokenManager phoneVerificationTokenManager,
final RegistrationLockVerificationManager registrationLockVerificationManager,
final Keys keys,
final KeysManager keysManager,
final RateLimiters rateLimiters) {
this.accounts = accounts;
this.phoneVerificationTokenManager = phoneVerificationTokenManager;
this.registrationLockVerificationManager = registrationLockVerificationManager;
this.keys = keys;
this.keysManager = keysManager;
this.rateLimiters = rateLimiters;
}
@@ -176,8 +176,8 @@ public class RegistrationController {
registrationRequest.deviceActivationRequest().gcmToken().ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
keys.storePqLastResort(a.getUuid(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey().get()));
keys.storePqLastResort(a.getPhoneNumberIdentifier(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey().get()));
keysManager.storePqLastResort(a.getUuid(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey().get()));
keysManager.storePqLastResort(a.getPhoneNumberIdentifier(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey().get()));
});
}

View File

@@ -43,7 +43,7 @@ public record ChangeNumberRequest(
@NotEmpty byte[] pniIdentityKey,
@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keys
A list of synchronization messages to send to companion devices to supply the private keysManager
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")
@NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages,

View File

@@ -36,7 +36,7 @@ public record ChangePhoneNumberRequest(
@Nullable byte[] pniIdentityKey,
@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keys
A list of synchronization messages to send to companion devices to supply the private keysManager
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")
@Nullable List<IncomingMessage> deviceMessages,

View File

@@ -30,7 +30,7 @@ public record PhoneNumberIdentityKeyDistributionRequest(
@NotNull
@Valid
@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keys
A list of synchronization messages to send to companion devices to supply the private keysManager
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")
List<@NotNull @Valid IncomingMessage> deviceMessages,

View File

@@ -90,7 +90,7 @@ public class AccountsManager {
private final FaultTolerantRedisCluster cacheCluster;
private final AccountLockManager accountLockManager;
private final DeletedAccounts deletedAccounts;
private final Keys keys;
private final KeysManager keysManager;
private final MessagesManager messagesManager;
private final ProfilesManager profilesManager;
private final StoredVerificationCodeManager pendingAccounts;
@@ -134,7 +134,7 @@ public class AccountsManager {
final FaultTolerantRedisCluster cacheCluster,
final AccountLockManager accountLockManager,
final DeletedAccounts deletedAccounts,
final Keys keys,
final KeysManager keysManager,
final MessagesManager messagesManager,
final ProfilesManager profilesManager,
final StoredVerificationCodeManager pendingAccounts,
@@ -150,7 +150,7 @@ public class AccountsManager {
this.cacheCluster = cacheCluster;
this.accountLockManager = accountLockManager;
this.deletedAccounts = deletedAccounts;
this.keys = keys;
this.keysManager = keysManager;
this.messagesManager = messagesManager;
this.profilesManager = profilesManager;
this.pendingAccounts = pendingAccounts;
@@ -223,8 +223,8 @@ public class AccountsManager {
// account and need to clear out messages and keys that may have been stored for the old account.
if (!originalUuid.equals(actualUuid)) {
messagesManager.clear(actualUuid);
keys.delete(actualUuid);
keys.delete(account.getPhoneNumberIdentifier());
keysManager.delete(actualUuid);
keysManager.delete(account.getPhoneNumberIdentifier());
profilesManager.deleteAll(actualUuid);
clientPresenceManager.disconnectAllPresencesForUuid(actualUuid);
}
@@ -315,13 +315,13 @@ public class AccountsManager {
updatedAccount.set(numberChangedAccount);
keys.delete(phoneNumberIdentifier);
keys.delete(originalPhoneNumberIdentifier);
keysManager.delete(phoneNumberIdentifier);
keysManager.delete(originalPhoneNumberIdentifier);
if (pniPqLastResortPreKeys != null) {
keys.storePqLastResort(
keysManager.storePqLastResort(
phoneNumberIdentifier,
keys.getPqEnabledDevices(uuid).stream().collect(
keysManager.getPqEnabledDevices(uuid).stream().collect(
Collectors.toMap(
Function.identity(),
pniPqLastResortPreKeys::get)));
@@ -356,10 +356,10 @@ public class AccountsManager {
final UUID pni = account.getPhoneNumberIdentifier();
final Account updatedAccount = update(account, a -> { return setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); });
final List<Long> pqEnabledDeviceIDs = keys.getPqEnabledDevices(pni);
keys.delete(pni);
final List<Long> pqEnabledDeviceIDs = keysManager.getPqEnabledDevices(pni);
keysManager.delete(pni);
if (pniPqLastResortPreKeys != null) {
keys.storePqLastResort(pni, pqEnabledDeviceIDs.stream().collect(Collectors.toMap(Function.identity(), pniPqLastResortPreKeys::get)));
keysManager.storePqLastResort(pni, pqEnabledDeviceIDs.stream().collect(Collectors.toMap(Function.identity(), pniPqLastResortPreKeys::get)));
}
return updatedAccount;
@@ -740,8 +740,8 @@ public class AccountsManager {
account.getUuid());
profilesManager.deleteAll(account.getUuid());
keys.delete(account.getUuid());
keys.delete(account.getPhoneNumberIdentifier());
keysManager.delete(account.getUuid());
keysManager.delete(account.getPhoneNumberIdentifier());
messagesManager.clear(account.getUuid());
messagesManager.clear(account.getPhoneNumberIdentifier());
registrationRecoveryPasswordsManager.removeForNumber(account.getNumber());

View File

@@ -1,417 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Multimap;
import com.google.common.collect.MultimapBuilder;
import com.google.common.collect.Multimaps;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import io.micrometer.core.instrument.Counter;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse;
import software.amazon.awssdk.services.dynamodb.model.DeleteRequest;
import software.amazon.awssdk.services.dynamodb.model.PutRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.Select;
import software.amazon.awssdk.services.dynamodb.model.WriteRequest;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
public class Keys extends AbstractDynamoDbStore {
private final String ecTableName;
private final String pqTableName;
private final String pqLastResortTableName;
static final String KEY_ACCOUNT_UUID = "U";
static final String KEY_DEVICE_ID_KEY_ID = "DK";
static final String KEY_PUBLIC_KEY = "P";
static final String KEY_SIGNATURE = "S";
private static final Timer STORE_KEYS_TIMER = Metrics.timer(name(Keys.class, "storeKeys"));
private static final Timer TAKE_KEY_FOR_DEVICE_TIMER = Metrics.timer(name(Keys.class, "takeKeyForDevice"));
private static final Timer GET_KEY_COUNT_TIMER = Metrics.timer(name(Keys.class, "getKeyCount"));
private static final Timer DELETE_KEYS_FOR_DEVICE_TIMER = Metrics.timer(name(Keys.class, "deleteKeysForDevice"));
private static final Timer DELETE_KEYS_FOR_ACCOUNT_TIMER = Metrics.timer(name(Keys.class, "deleteKeysForAccount"));
private static final DistributionSummary CONTESTED_KEY_DISTRIBUTION = Metrics.summary(name(Keys.class, "contestedKeys"));
private static final DistributionSummary KEY_COUNT_DISTRIBUTION = Metrics.summary(name(Keys.class, "keyCount"));
private static final Counter KEYS_EMPTY_TAKE_COUNTER = Metrics.counter(name(Keys.class, "takeKeyEmpty"));
private static final Counter TOO_MANY_LAST_RESORT_KEYS_COUNTER = Metrics.counter(name(Keys.class, "tooManyLastResortKeys"));
private static final Counter PARSE_BYTES_FROM_STRING_COUNTER = Metrics.counter(name(Keys.class, "parseByteArray"), "format", "string");
private static final Counter READ_BYTES_FROM_BYTE_ARRAY_COUNTER = Metrics.counter(name(Keys.class, "parseByteArray"), "format", "bytes");
public Keys(
final DynamoDbClient dynamoDB,
final String ecTableName,
final String pqTableName,
final String pqLastResortTableName) {
super(dynamoDB);
this.ecTableName = ecTableName;
this.pqTableName = pqTableName;
this.pqLastResortTableName = pqLastResortTableName;
}
public void store(final UUID identifier, final long deviceId, final List<PreKey> keys) {
store(identifier, deviceId, keys, null, null);
}
public void store(
final UUID identifier, final long deviceId,
@Nullable final List<PreKey> ecKeys,
@Nullable final List<SignedPreKey> pqKeys,
@Nullable final SignedPreKey pqLastResortKey) {
Multimap<String, PreKey> keys = MultimapBuilder.hashKeys().arrayListValues().build();
List<String> tablesToClear = new ArrayList<>();
if (ecKeys != null && !ecKeys.isEmpty()) {
keys.putAll(ecTableName, ecKeys);
tablesToClear.add(ecTableName);
}
if (pqKeys != null && !pqKeys.isEmpty()) {
keys.putAll(pqTableName, pqKeys);
tablesToClear.add(pqTableName);
}
if (pqLastResortKey != null) {
keys.put(pqLastResortTableName, pqLastResortKey);
tablesToClear.add(pqLastResortTableName);
}
STORE_KEYS_TIMER.record(() -> {
delete(tablesToClear, identifier, deviceId);
writeInBatches(
keys.entries(),
batch -> {
Multimap<String, WriteRequest> writes = batch.stream()
.collect(
Multimaps.toMultimap(
Map.Entry<String, PreKey>::getKey,
entry -> WriteRequest.builder()
.putRequest(PutRequest.builder()
.item(getItemFromPreKey(identifier, deviceId, entry.getValue()))
.build())
.build(),
MultimapBuilder.hashKeys().arrayListValues()::build));
executeTableWriteItemsUntilComplete(writes.asMap());
});
});
}
public void storePqLastResort(final UUID identifier, final Map<Long, SignedPreKey> keys) {
final AttributeValue partitionKey = getPartitionKey(identifier);
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(pqLastResortTableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", partitionKey))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build();
final List<WriteRequest> writes = new ArrayList<>(2 * keys.size());
final Map<Long, Map<String, AttributeValue>> newItems = keys.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> getItemFromPreKey(identifier, e.getKey(), e.getValue())));
for (final Map<String, AttributeValue> item : db().query(queryRequest).items()) {
final AttributeValue oldSortKey = item.get(KEY_DEVICE_ID_KEY_ID);
final Long oldDeviceId = oldSortKey.b().asByteBuffer().getLong();
if (newItems.containsKey(oldDeviceId)) {
final Map<String, AttributeValue> replacement = newItems.get(oldDeviceId);
if (!replacement.get(KEY_DEVICE_ID_KEY_ID).equals(oldSortKey)) {
writes.add(WriteRequest.builder()
.deleteRequest(DeleteRequest.builder()
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, oldSortKey))
.build())
.build());
}
}
}
newItems.forEach((unusedKey, item) ->
writes.add(WriteRequest.builder().putRequest(PutRequest.builder().item(item).build()).build()));
executeTableWriteItemsUntilComplete(Map.of(pqLastResortTableName, writes));
}
public Optional<PreKey> takeEC(final UUID identifier, final long deviceId) {
return take(ecTableName, identifier, deviceId);
}
public Optional<SignedPreKey> takePQ(final UUID identifier, final long deviceId) {
return take(pqTableName, identifier, deviceId)
.or(() -> getLastResort(identifier, deviceId))
.map(pk -> (SignedPreKey) pk);
}
private Optional<PreKey> take(final String tableName, final UUID identifier, final long deviceId) {
return TAKE_KEY_FOR_DEVICE_TIMER.record(() -> {
final AttributeValue partitionKey = getPartitionKey(identifier);
QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", partitionKey,
":sortprefix", getSortKeyPrefix(deviceId)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(false)
.build();
int contestedKeys = 0;
try {
QueryResponse response = db().query(queryRequest);
for (Map<String, AttributeValue> candidate : response.items()) {
DeleteItemRequest deleteItemRequest = DeleteItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, candidate.get(KEY_DEVICE_ID_KEY_ID)))
.returnValues(ReturnValue.ALL_OLD)
.build();
DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest);
if (deleteItemResponse.hasAttributes()) {
return Optional.of(getPreKeyFromItem(deleteItemResponse.attributes()));
}
contestedKeys++;
}
KEYS_EMPTY_TAKE_COUNTER.increment();
return Optional.empty();
} finally {
CONTESTED_KEY_DISTRIBUTION.record(contestedKeys);
}
});
}
@VisibleForTesting
Optional<PreKey> getLastResort(final UUID identifier, final long deviceId) {
final AttributeValue partitionKey = getPartitionKey(identifier);
QueryRequest queryRequest = QueryRequest.builder()
.tableName(pqLastResortTableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", partitionKey,
":sortprefix", getSortKeyPrefix(deviceId)))
.consistentRead(false)
.select(Select.ALL_ATTRIBUTES)
.build();
QueryResponse response = db().query(queryRequest);
if (response.count() > 1) {
TOO_MANY_LAST_RESORT_KEYS_COUNTER.increment();
}
return response.items().stream().findFirst().map(this::getPreKeyFromItem);
}
public List<Long> getPqEnabledDevices(final UUID identifier) {
final AttributeValue partitionKey = getPartitionKey(identifier);
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(pqLastResortTableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", partitionKey))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(false)
.build();
final QueryResponse response = db().query(queryRequest);
return response.items().stream()
.map(item -> item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong())
.toList();
}
public int getEcCount(final UUID identifier, final long deviceId) {
return getCount(ecTableName, identifier, deviceId);
}
public int getPqCount(final UUID identifier, final long deviceId) {
return getCount(pqTableName, identifier, deviceId);
}
private int getCount(final String tableName, final UUID identifier, final long deviceId) {
return GET_KEY_COUNT_TIMER.record(() -> {
QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(identifier),
":sortprefix", getSortKeyPrefix(deviceId)))
.select(Select.COUNT)
.consistentRead(false)
.build();
int keyCount = 0;
// This is very confusing, but does appear to be the intended behavior. See:
//
// - https://github.com/aws/aws-sdk-java/issues/693
// - https://github.com/aws/aws-sdk-java/issues/915
// - https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Query.html#Query.Count
for (final QueryResponse page : db().queryPaginator(queryRequest)) {
keyCount += page.count();
}
KEY_COUNT_DISTRIBUTION.record(keyCount);
return keyCount;
});
}
public void delete(final UUID accountUuid) {
DELETE_KEYS_FOR_ACCOUNT_TIMER.record(() -> {
final QueryRequest queryRequest = QueryRequest.builder()
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(accountUuid)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build();
deleteItemsForAccountMatchingQuery(List.of(ecTableName, pqTableName, pqLastResortTableName), accountUuid, queryRequest);
});
}
public void delete(final UUID accountUuid, final long deviceId) {
delete(List.of(ecTableName, pqTableName, pqLastResortTableName), accountUuid, deviceId);
}
private void delete(final List<String> tableNames, final UUID accountUuid, final long deviceId) {
DELETE_KEYS_FOR_DEVICE_TIMER.record(() -> {
final QueryRequest queryRequest = QueryRequest.builder()
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(accountUuid),
":sortprefix", getSortKeyPrefix(deviceId)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build();
deleteItemsForAccountMatchingQuery(tableNames, accountUuid, queryRequest);
});
}
private void deleteItemsForAccountMatchingQuery(final List<String> tableNames, final UUID accountUuid, final QueryRequest querySpec) {
final AttributeValue partitionKey = getPartitionKey(accountUuid);
Multimap<String, Map<String, AttributeValue>> itemStream = tableNames.stream()
.collect(
Multimaps.flatteningToMultimap(
Function.identity(),
tableName ->
db().query(querySpec.toBuilder().tableName(tableName).build())
.items()
.stream(),
MultimapBuilder.hashKeys(tableNames.size()).arrayListValues()::build));
writeInBatches(
itemStream.entries(),
batch -> {
Multimap<String, WriteRequest> deletes = batch.stream()
.collect(Multimaps.toMultimap(
Map.Entry<String, Map<String, AttributeValue>>::getKey,
entry -> WriteRequest.builder()
.deleteRequest(DeleteRequest.builder()
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, entry.getValue().get(KEY_DEVICE_ID_KEY_ID)))
.build())
.build(),
MultimapBuilder.hashKeys(tableNames.size()).arrayListValues()::build));
executeTableWriteItemsUntilComplete(deletes.asMap());
});
}
private static AttributeValue getPartitionKey(final UUID accountUuid) {
return AttributeValues.fromUUID(accountUuid);
}
private static AttributeValue getSortKey(final long deviceId, final long keyId) {
final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]);
byteBuffer.putLong(deviceId);
byteBuffer.putLong(keyId);
return AttributeValues.fromByteBuffer(byteBuffer.flip());
}
@VisibleForTesting
static AttributeValue getSortKeyPrefix(final long deviceId) {
final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[8]);
byteBuffer.putLong(deviceId);
return AttributeValues.fromByteBuffer(byteBuffer.flip());
}
private Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final long deviceId, final PreKey preKey) {
if (preKey instanceof final SignedPreKey spk) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(accountUuid),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, spk.getKeyId()),
KEY_PUBLIC_KEY, AttributeValues.fromByteArray(spk.getPublicKey()),
KEY_SIGNATURE, AttributeValues.fromByteArray(spk.getSignature()));
}
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(accountUuid),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.getKeyId()),
KEY_PUBLIC_KEY, AttributeValues.fromByteArray(preKey.getPublicKey()));
}
private PreKey getPreKeyFromItem(Map<String, AttributeValue> item) {
final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
final byte[] publicKey = extractByteArray(item.get(KEY_PUBLIC_KEY));
if (item.containsKey(KEY_SIGNATURE)) {
// All PQ prekeys are signed, and therefore have this attribute. Signed EC prekeys are stored
// in the Accounts table, so EC prekeys retrieved by this class are never SignedPreKeys.
return new SignedPreKey(keyId, publicKey, extractByteArray(item.get(KEY_SIGNATURE)));
}
return new PreKey(keyId, publicKey);
}
/**
* Extracts a byte array from an {@link AttributeValue} that may be either a byte array or a base64-encoded string.
*
* @param attributeValue the {@code AttributeValue} from which to extract a byte array
*
* @return the byte array represented by the given {@code AttributeValue}
*/
@VisibleForTesting
static byte[] extractByteArray(final AttributeValue attributeValue) {
if (attributeValue.b() != null) {
READ_BYTES_FROM_BYTE_ARRAY_COUNTER.increment();
return attributeValue.b().asByteArray();
} else if (StringUtils.isNotBlank(attributeValue.s())) {
PARSE_BYTES_FROM_STRING_COUNTER.increment();
return Base64.getDecoder().decode(attributeValue.s());
}
throw new IllegalArgumentException("Attribute value has neither a byte array nor a string value");
}
}

View File

@@ -0,0 +1,111 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
public class KeysManager {
private final SingleUseECPreKeyStore ecPreKeys;
private final SingleUseKEMPreKeyStore pqPreKeys;
private final RepeatedUseSignedPreKeyStore pqLastResortKeys;
public KeysManager(
final DynamoDbAsyncClient dynamoDbAsyncClient,
final String ecTableName,
final String pqTableName,
final String pqLastResortTableName) {
this.ecPreKeys = new SingleUseECPreKeyStore(dynamoDbAsyncClient, ecTableName);
this.pqPreKeys = new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, pqTableName);
this.pqLastResortKeys = new RepeatedUseSignedPreKeyStore(dynamoDbAsyncClient, pqLastResortTableName);
}
public void store(final UUID identifier, final long deviceId, final List<PreKey> keys) {
store(identifier, deviceId, keys, null, null);
}
public void store(
final UUID identifier, final long deviceId,
@Nullable final List<PreKey> ecKeys,
@Nullable final List<SignedPreKey> pqKeys,
@Nullable final SignedPreKey pqLastResortKey) {
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>();
if (ecKeys != null && !ecKeys.isEmpty()) {
storeFutures.add(ecPreKeys.store(identifier, deviceId, ecKeys));
}
if (pqKeys != null && !pqKeys.isEmpty()) {
storeFutures.add(pqPreKeys.store(identifier, deviceId, pqKeys));
}
if (pqLastResortKey != null) {
storeFutures.add(pqLastResortKeys.store(identifier, deviceId, pqLastResortKey));
}
CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0])).join();
}
public void storePqLastResort(final UUID identifier, final Map<Long, SignedPreKey> keys) {
pqLastResortKeys.store(identifier, keys).join();
}
public Optional<PreKey> takeEC(final UUID identifier, final long deviceId) {
return ecPreKeys.take(identifier, deviceId).join();
}
public Optional<SignedPreKey> takePQ(final UUID identifier, final long deviceId) {
return pqPreKeys.take(identifier, deviceId)
.thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey
.map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey))
.orElseGet(() -> pqLastResortKeys.find(identifier, deviceId))).join();
}
@VisibleForTesting
Optional<PreKey> getLastResort(final UUID identifier, final long deviceId) {
return pqLastResortKeys.find(identifier, deviceId).join()
.map(signedPreKey -> signedPreKey);
}
public List<Long> getPqEnabledDevices(final UUID identifier) {
return pqLastResortKeys.getDeviceIdsWithKeys(identifier).collectList().block();
}
public int getEcCount(final UUID identifier, final long deviceId) {
return ecPreKeys.getCount(identifier, deviceId).join();
}
public int getPqCount(final UUID identifier, final long deviceId) {
return pqPreKeys.getCount(identifier, deviceId).join();
}
public void delete(final UUID accountUuid) {
CompletableFuture.allOf(
ecPreKeys.delete(accountUuid),
pqPreKeys.delete(accountUuid),
pqLastResortKeys.delete(accountUuid))
.join();
}
public void delete(final UUID accountUuid, final long deviceId) {
CompletableFuture.allOf(
ecPreKeys.delete(accountUuid, deviceId),
pqPreKeys.delete(accountUuid, deviceId),
pqLastResortKeys.delete(accountUuid, deviceId))
.join();
}
}

View File

@@ -0,0 +1,228 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.Put;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest;
/**
* A repeated-use signed pre-key store manages storage for pre-keys that may be used more than once. Generally, these
* are considered "last resort" keys and should only be used when a device's supply of single-use pre-keys has been
* exhausted.
* <p/>
* Each {@link Account} may have one or more {@link Device devices}. Each "active" (i.e. those that have completed
* provisioning and are capable of sending and receiving messages) must have exactly one "last resort" pre-key.
*/
public class RepeatedUseSignedPreKeyStore {
private final DynamoDbAsyncClient dynamoDbAsyncClient;
private final String tableName;
static final String KEY_ACCOUNT_UUID = "U";
static final String KEY_DEVICE_ID = "D";
static final String ATTR_KEY_ID = "I";
static final String ATTR_PUBLIC_KEY = "P";
static final String ATTR_SIGNATURE = "S";
private static final Timer STORE_SINGLE_KEY_TIMER =
Metrics.timer(MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "storeSingleKey"));
private static final Timer STORE_KEY_BATCH_TIMER =
Metrics.timer(MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "storeKeyBatch"));
private static final Timer DELETE_FOR_DEVICE_TIMER =
Metrics.timer(MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "deleteForDevice"));
private static final Timer DELETE_FOR_ACCOUNT_TIMER =
Metrics.timer(MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "deleteForAccount"));
private static final String FIND_KEY_TIMER_NAME = MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "findKey");
private static final String KEY_PRESENT_TAG_NAME = "keyPresent";
public RepeatedUseSignedPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
this.dynamoDbAsyncClient = dynamoDbAsyncClient;
this.tableName = tableName;
}
/**
* Stores a repeated-use pre-key for a specific device, displacing any previously-stored repeated-use pre-key for that
* device.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
* @param signedPreKey the key to store for the target device
*
* @return a future that completes once the key has been stored
*/
public CompletableFuture<Void> store(final UUID identifier, final long deviceId, final SignedPreKey signedPreKey) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
.tableName(tableName)
.item(getItemFromPreKey(identifier, deviceId, signedPreKey))
.build())
.thenRun(() -> sample.stop(STORE_SINGLE_KEY_TIMER));
}
/**
* Stores repeated-use pre-keys for a collection of devices associated with a single account/identity, displacing any
* previously-stored repeated-use pre-keys for the targeted devices. Note that this method is transactional; either
* all keys will be stored or none will.
*
* @param identifier the identifier for the account/identity with which the target devices are associated
* @param signedPreKeysByDeviceId a map of device identifiers to pre-keys
*
* @return a future that completes once all keys have been stored
*/
public CompletableFuture<Void> store(final UUID identifier, final Map<Long, SignedPreKey> signedPreKeysByDeviceId) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.transactWriteItems(TransactWriteItemsRequest.builder()
.transactItems(signedPreKeysByDeviceId.entrySet().stream()
.map(entry -> {
final long deviceId = entry.getKey();
final SignedPreKey signedPreKey = entry.getValue();
return TransactWriteItem.builder()
.put(Put.builder()
.tableName(tableName)
.item(getItemFromPreKey(identifier, deviceId, signedPreKey))
.build())
.build();
})
.toList())
.build())
.thenRun(() -> sample.stop(STORE_KEY_BATCH_TIMER));
}
/**
* Finds a repeated-use pre-key for a specific device.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
*
* @return a future that yields an optional signed pre-key if one is available for the target device or empty if no
* key could be found for the target device
*/
public CompletableFuture<Optional<SignedPreKey>> find(final UUID identifier, final long deviceId) {
final Timer.Sample sample = Timer.start();
final CompletableFuture<Optional<SignedPreKey>> findFuture = dynamoDbAsyncClient.getItem(GetItemRequest.builder()
.tableName(tableName)
.key(getPrimaryKey(identifier, deviceId))
.consistentRead(true)
.build())
.thenApply(response -> response.hasItem() ? Optional.of(getPreKeyFromItem(response.item())) : Optional.empty());
findFuture.whenComplete((maybeSignedPreKey, throwable) ->
sample.stop(Metrics.timer(FIND_KEY_TIMER_NAME, KEY_PRESENT_TAG_NAME, String.valueOf(maybeSignedPreKey != null && maybeSignedPreKey.isPresent()))));
return findFuture;
}
/**
* Clears all repeated-use pre-keys associated with the given account/identity.
*
* @param identifier the identifier for the account/identity for which to clear repeated-use pre-keys
*
* @return a future that completes once repeated-use pre-keys have been cleared from all devices associated with the
* target account/identity
*/
public CompletableFuture<Void> delete(final UUID identifier) {
final Timer.Sample sample = Timer.start();
return getDeviceIdsWithKeys(identifier)
.map(deviceId -> DeleteItemRequest.builder()
.tableName(tableName)
.key(getPrimaryKey(identifier, deviceId))
.build())
.flatMap(deleteItemRequest -> Mono.fromFuture(dynamoDbAsyncClient.deleteItem(deleteItemRequest)))
// Idiom: wait for everything to finish, but discard the results
.reduce(0, (a, b) -> 0)
.toFuture()
.thenRun(() -> sample.stop(DELETE_FOR_ACCOUNT_TIMER));
}
/**
* Removes the repeated-use pre-key associated with a specific device.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
*
* @return a future that completes once the repeated-use pre-key has been removed from the target device
*/
public CompletableFuture<Void> delete(final UUID identifier, final long deviceId) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.deleteItem(DeleteItemRequest.builder()
.tableName(tableName)
.key(getPrimaryKey(identifier, deviceId))
.build())
.thenRun(() -> sample.stop(DELETE_FOR_DEVICE_TIMER));
}
public Flux<Long> getDeviceIdsWithKeys(final UUID identifier) {
return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(identifier)))
.projectionExpression(KEY_DEVICE_ID)
.consistentRead(true)
.build())
.items())
.map(item -> Long.parseLong(item.get(KEY_DEVICE_ID).n()));
}
private static Map<String, AttributeValue> getPrimaryKey(final UUID identifier, final long deviceId) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(identifier),
KEY_DEVICE_ID, getSortKey(deviceId));
}
private static AttributeValue getPartitionKey(final UUID accountUuid) {
return AttributeValues.fromUUID(accountUuid);
}
private static AttributeValue getSortKey(final long deviceId) {
return AttributeValues.fromLong(deviceId);
}
private static Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final long deviceId, final SignedPreKey signedPreKey) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(accountUuid),
KEY_DEVICE_ID, getSortKey(deviceId),
ATTR_KEY_ID, AttributeValues.fromLong(signedPreKey.getKeyId()),
ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.getPublicKey()),
ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.getSignature()));
}
private static SignedPreKey getPreKeyFromItem(final Map<String, AttributeValue> item) {
return new SignedPreKey(
Long.parseLong(item.get(ATTR_KEY_ID).n()),
item.get(ATTR_PUBLIC_KEY).b().asByteArray(),
item.get(ATTR_SIGNATURE).b().asByteArray());
}
}

View File

@@ -0,0 +1,36 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import java.util.Map;
import java.util.UUID;
public class SingleUseECPreKeyStore extends SingleUsePreKeyStore<PreKey> {
protected SingleUseECPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
super(dynamoDbAsyncClient, tableName);
}
@Override
protected Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final long deviceId, final PreKey preKey) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(identifier),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.getKeyId()),
ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(preKey.getPublicKey()));
}
@Override
protected PreKey getPreKeyFromItem(final Map<String, AttributeValue> item) {
final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
final byte[] publicKey = extractByteArray(item.get(ATTR_PUBLIC_KEY));
return new PreKey(keyId, publicKey);
}
}

View File

@@ -0,0 +1,38 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import java.util.Map;
import java.util.UUID;
public class SingleUseKEMPreKeyStore extends SingleUsePreKeyStore<SignedPreKey> {
protected SingleUseKEMPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
super(dynamoDbAsyncClient, tableName);
}
@Override
protected Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final long deviceId, final SignedPreKey signedPreKey) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(identifier),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, signedPreKey.getKeyId()),
ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.getPublicKey()),
ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.getSignature()));
}
@Override
protected SignedPreKey getPreKeyFromItem(final Map<String, AttributeValue> item) {
final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
final byte[] publicKey = extractByteArray(item.get(ATTR_PUBLIC_KEY));
final byte[] signature = extractByteArray(item.get(ATTR_SIGNATURE));
return new SignedPreKey(keyId, publicKey, signature);
}
}

View File

@@ -0,0 +1,312 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.Select;
/**
* A single-use pre-key store stores single-use pre-keys of a specific type. Keys returned by a single-use pre-key
* store's {@link #take(UUID, long)} method are guaranteed to be returned exactly once, and repeated calls will never
* yield the same key.
* <p/>
* Each {@link Account} may have one or more {@link Device devices}. Clients <em>should</em> regularly check their
* supply of single-use pre-keys (see {@link #getCount(UUID, long)}) and upload new keys when their supply runs low. In
* the event that a party wants to begin a session with a device that has no single-use pre-keys remaining, that party
* may fall back to using the device's repeated-use ("last-resort") signed pre-key instead.
*/
public abstract class SingleUsePreKeyStore<K extends PreKey> {
private final DynamoDbAsyncClient dynamoDbAsyncClient;
private final String tableName;
private final Timer storeKeyTimer = Metrics.timer(name(getClass(), "storeKey"));
private final Timer storeKeyBatchTimer = Metrics.timer(name(getClass(), "storeKeyBatch"));
private final Timer getKeyCountTimer = Metrics.timer(name(getClass(), "getCount"));
private final Timer deleteForDeviceTimer = Metrics.timer(name(getClass(), "deleteForDevice"));
private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount"));
final DistributionSummary keysConsideredForTakeDistributionSummary = DistributionSummary
.builder(name(getClass(), "keysConsideredForTake"))
.publishPercentiles(0.5, 0.75, 0.95, 0.99, 0.999)
.distributionStatisticExpiry(Duration.ofMinutes(10))
.register(Metrics.globalRegistry);
final DistributionSummary availableKeyCountDistributionSummary = DistributionSummary
.builder(name(getClass(), "availableKeyCount"))
.publishPercentiles(0.5, 0.75, 0.95, 0.99, 0.999)
.distributionStatisticExpiry(Duration.ofMinutes(10))
.register(Metrics.globalRegistry);
private final String takeKeyTimerName = name(getClass(), "takeKey");
private static final String KEY_PRESENT_TAG_NAME = "keyPresent";
private final Counter parseBytesFromStringCounter = Metrics.counter(name(getClass(), "parseByteArray"), "format", "string");
private final Counter readBytesFromByteArrayCounter = Metrics.counter(name(getClass(), "parseByteArray"), "format", "bytes");
static final String KEY_ACCOUNT_UUID = "U";
static final String KEY_DEVICE_ID_KEY_ID = "DK";
static final String ATTR_PUBLIC_KEY = "P";
static final String ATTR_SIGNATURE = "S";
protected SingleUsePreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
this.dynamoDbAsyncClient = dynamoDbAsyncClient;
this.tableName = tableName;
}
/**
* Stores a batch of single-use pre-keys for a specific device. All previously-stored keys for the device are cleared
* before storing new keys.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
* @param preKeys a collection of single-use pre-keys to store for the target device
*
* @return a future that completes when all previously-stored keys have been removed and the given collection of
* pre-keys has been stored in its place
*/
public CompletableFuture<Void> store(final UUID identifier, final long deviceId, final List<K> preKeys) {
final Timer.Sample sample = Timer.start();
return delete(identifier, deviceId)
.thenCompose(ignored -> CompletableFuture.allOf(preKeys.stream()
.map(preKey -> store(identifier, deviceId, preKey))
.toList()
.toArray(new CompletableFuture[0])))
.thenRun(() -> sample.stop(storeKeyBatchTimer));
}
private CompletableFuture<Void> store(final UUID identifier, final long deviceId, final K preKey) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
.tableName(tableName)
.item(getItemFromPreKey(identifier, deviceId, preKey))
.build())
.thenRun(() -> sample.stop(storeKeyTimer));
}
/**
* Attempts to retrieve a single-use pre-key for a specific device. Keys may only be returned by this method at most
* once; once the key is returned, it is removed from the key store and subsequent calls to this method will never
* return the same key.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
*
* @return a future that yields a single-use pre-key if one is available or empty if no single-use pre-keys are
* available for the target device
*/
public CompletableFuture<Optional<K>> take(final UUID identifier, final long deviceId) {
final Timer.Sample sample = Timer.start();
final AttributeValue partitionKey = getPartitionKey(identifier);
final AtomicInteger keysConsidered = new AtomicInteger(0);
return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", partitionKey,
":sortprefix", getSortKeyPrefix(deviceId)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(false)
.build())
.items())
.map(item -> DeleteItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, item.get(KEY_DEVICE_ID_KEY_ID)))
.returnValues(ReturnValue.ALL_OLD)
.build())
.flatMap(deleteItemRequest -> Mono.fromFuture(dynamoDbAsyncClient.deleteItem(deleteItemRequest)), 1)
.doOnNext(deleteItemResponse -> keysConsidered.incrementAndGet())
.filter(DeleteItemResponse::hasAttributes)
.next()
.map(deleteItemResponse -> getPreKeyFromItem(deleteItemResponse.attributes()))
.toFuture()
.thenApply(Optional::ofNullable)
.whenComplete((maybeKey, throwable) -> {
sample.stop(Metrics.timer(takeKeyTimerName, KEY_PRESENT_TAG_NAME, String.valueOf(maybeKey != null && maybeKey.isPresent())));
keysConsideredForTakeDistributionSummary.record(keysConsidered.get());
});
}
/**
* Estimates the number of single-use pre-keys available for a given device.
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
* @return a future that yields the approximate number of single-use pre-keys currently available for the target
* device
*/
public CompletableFuture<Integer> getCount(final UUID identifier, final long deviceId) {
final Timer.Sample sample = Timer.start();
// Getting an accurate count from DynamoDB can be very confusing. See:
//
// - https://github.com/aws/aws-sdk-java/issues/693
// - https://github.com/aws/aws-sdk-java/issues/915
// - https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Query.html#Query.Count
return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(identifier),
":sortprefix", getSortKeyPrefix(deviceId)))
.select(Select.COUNT)
.consistentRead(false)
.build()))
.map(QueryResponse::count)
.reduce(0, Integer::sum)
.toFuture()
.whenComplete((keyCount, throwable) -> {
sample.stop(getKeyCountTimer);
if (throwable == null && keyCount != null) {
availableKeyCountDistributionSummary.record(keyCount);
}
});
}
/**
* Removes all single-use pre-keys for all devices associated with the given account/identity.
*
* @param identifier the identifier for the account/identity for which to remove single-use pre-keys
*
* @return a future that completes when all single-use pre-keys have been removed for all devices associated with the
* given account/identity
*/
public CompletableFuture<Void> delete(final UUID identifier) {
final Timer.Sample sample = Timer.start();
return deleteItems(getPartitionKey(identifier), Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", getPartitionKey(identifier)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build())
.items()))
.thenRun(() -> sample.stop(deleteForAccountTimer));
}
/**
* Removes all single-use pre-keys for a specific device.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
* @return a future that completes when all single-use pre-keys have been removed for the target device
*/
public CompletableFuture<Void> delete(final UUID identifier, final long deviceId) {
final Timer.Sample sample = Timer.start();
return deleteItems(getPartitionKey(identifier), Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(identifier),
":sortprefix", getSortKeyPrefix(deviceId)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build())
.items()))
.thenRun(() -> sample.stop(deleteForDeviceTimer));
}
private CompletableFuture<Void> deleteItems(final AttributeValue partitionKey, final Flux<Map<String, AttributeValue>> items) {
return items
.map(item -> DeleteItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, item.get(KEY_DEVICE_ID_KEY_ID)
))
.build())
.flatMap(deleteItemRequest -> Mono.fromFuture(dynamoDbAsyncClient.deleteItem(deleteItemRequest)))
// Idiom: wait for everything to finish, but discard the results
.reduce(0, (a, b) -> 0)
.toFuture()
.thenRun(Util.NOOP);
}
protected static AttributeValue getPartitionKey(final UUID accountUuid) {
return AttributeValues.fromUUID(accountUuid);
}
protected static AttributeValue getSortKey(final long deviceId, final long keyId) {
final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]);
byteBuffer.putLong(deviceId);
byteBuffer.putLong(keyId);
return AttributeValues.fromByteBuffer(byteBuffer.flip());
}
private static AttributeValue getSortKeyPrefix(final long deviceId) {
final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[8]);
byteBuffer.putLong(deviceId);
return AttributeValues.fromByteBuffer(byteBuffer.flip());
}
protected abstract Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final long deviceId,
final K preKey);
protected abstract K getPreKeyFromItem(final Map<String, AttributeValue> item);
/**
* Extracts a byte array from an {@link AttributeValue} that may be either a byte array or a base64-encoded string.
*
* @param attributeValue the {@code AttributeValue} from which to extract a byte array
*
* @return the byte array represented by the given {@code AttributeValue}
*/
@VisibleForTesting
byte[] extractByteArray(final AttributeValue attributeValue) {
if (attributeValue.b() != null) {
readBytesFromByteArrayCounter.increment();
return attributeValue.b().asByteArray();
} else if (StringUtils.isNotBlank(attributeValue.s())) {
parseBytesFromStringCounter.increment();
return Base64.getDecoder().decode(attributeValue.s());
}
throw new IllegalArgumentException("Attribute value has neither a byte array nor a string value");
}
}

View File

@@ -42,7 +42,7 @@ import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeletedAccounts;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
@@ -171,10 +171,11 @@ public class AssignUsernameCommand extends EnvironmentCommand<WhisperServerConfi
configuration.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getProfiles().getTableName());
Keys keys = new Keys(dynamoDbClient,
KeysManager keys = new KeysManager(
dynamoDbAsyncClient,
configuration.getDynamoDbTables().getEcKeys().getTableName(),
configuration.getDynamoDbTables().getPqKeys().getTableName(),
configuration.getDynamoDbTables().getPqLastResortKeys().getTableName());
configuration.getDynamoDbTables().getKemKeys().getTableName(),
configuration.getDynamoDbTables().getKemLastResortKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration(),

View File

@@ -36,7 +36,7 @@ import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeletedAccounts;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
@@ -66,7 +66,7 @@ record CommandDependencies(
MessagesManager messagesManager,
StoredVerificationCodeManager pendingAccountsManager,
ClientPresenceManager clientPresenceManager,
Keys keys,
KeysManager keysManager,
FaultTolerantRedisCluster cacheCluster,
ClientResources redisClusterClientResources) {
@@ -153,10 +153,11 @@ record CommandDependencies(
configuration.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getProfiles().getTableName());
Keys keys = new Keys(dynamoDbClient,
KeysManager keys = new KeysManager(
dynamoDbAsyncClient,
configuration.getDynamoDbTables().getEcKeys().getTableName(),
configuration.getDynamoDbTables().getPqKeys().getTableName(),
configuration.getDynamoDbTables().getPqLastResortKeys().getTableName());
configuration.getDynamoDbTables().getKemKeys().getTableName(),
configuration.getDynamoDbTables().getKemLastResortKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration(),

View File

@@ -65,7 +65,7 @@ public class UnlinkDeviceCommand extends EnvironmentCommand<WhisperServerConfigu
account = deps.accountsManager().update(account, a -> a.removeDevice(deviceId));
System.out.format("Removing keys for device %s::%d\n", aci, deviceId);
deps.keys().delete(account.getUuid(), deviceId);
deps.keysManager().delete(account.getUuid(), deviceId);
System.out.format("Clearing additional messages for %s::%d\n", aci, deviceId);
deps.messagesManager().clear(account.getUuid(), deviceId);