mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-21 20:28:06 +01:00
Use pre-calculated pre-key counts when possible
This commit is contained in:
@@ -9,6 +9,12 @@ import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.whispersystems.textsecuregcm.entities.ECPreKey;
|
||||
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
|
||||
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
|
||||
import software.amazon.awssdk.services.dynamodb.model.ScanResponse;
|
||||
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
|
||||
import software.amazon.awssdk.services.dynamodb.paginators.ScanIterable;
|
||||
import java.util.Map;
|
||||
|
||||
class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest<ECPreKey> {
|
||||
|
||||
@@ -32,4 +38,24 @@ class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest<ECPreKey> {
|
||||
protected ECPreKey generatePreKey(final long keyId) {
|
||||
return new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void clearKeyCountAttributes() {
|
||||
final ScanIterable scanIterable = DYNAMO_DB_EXTENSION.getDynamoDbClient().scanPaginator(ScanRequest.builder()
|
||||
.tableName(DynamoDbExtensionSchema.Tables.EC_KEYS.tableName())
|
||||
.build());
|
||||
|
||||
for (final ScanResponse response : scanIterable) {
|
||||
for (final Map<String, AttributeValue> item : response.items()) {
|
||||
|
||||
DYNAMO_DB_EXTENSION.getDynamoDbClient().updateItem(UpdateItemRequest.builder()
|
||||
.tableName(DynamoDbExtensionSchema.Tables.EC_KEYS.tableName())
|
||||
.key(Map.of(
|
||||
SingleUsePreKeyStore.KEY_ACCOUNT_UUID, item.get(SingleUsePreKeyStore.KEY_ACCOUNT_UUID),
|
||||
SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID, item.get(SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID)))
|
||||
.updateExpression("REMOVE " + SingleUsePreKeyStore.ATTR_REMAINING_KEYS)
|
||||
.build());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,12 @@ import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
|
||||
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
|
||||
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
|
||||
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
|
||||
import software.amazon.awssdk.services.dynamodb.model.ScanResponse;
|
||||
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
|
||||
import software.amazon.awssdk.services.dynamodb.paginators.ScanIterable;
|
||||
import java.util.Map;
|
||||
|
||||
class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest<KEMSignedPreKey> {
|
||||
|
||||
@@ -36,4 +42,24 @@ class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest<KEMSignedPreK
|
||||
protected KEMSignedPreKey generatePreKey(final long keyId) {
|
||||
return KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void clearKeyCountAttributes() {
|
||||
final ScanIterable scanIterable = DYNAMO_DB_EXTENSION.getDynamoDbClient().scanPaginator(ScanRequest.builder()
|
||||
.tableName(DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName())
|
||||
.build());
|
||||
|
||||
for (final ScanResponse response : scanIterable) {
|
||||
for (final Map<String, AttributeValue> item : response.items()) {
|
||||
|
||||
DYNAMO_DB_EXTENSION.getDynamoDbClient().updateItem(UpdateItemRequest.builder()
|
||||
.tableName(DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName())
|
||||
.key(Map.of(
|
||||
SingleUsePreKeyStore.KEY_ACCOUNT_UUID, item.get(SingleUsePreKeyStore.KEY_ACCOUNT_UUID),
|
||||
SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID, item.get(SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID)))
|
||||
.updateExpression("REMOVE " + SingleUsePreKeyStore.ATTR_REMAINING_KEYS)
|
||||
.build());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,10 +9,16 @@ import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.whispersystems.textsecuregcm.entities.PreKey;
|
||||
|
||||
abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
|
||||
@@ -23,6 +29,8 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
|
||||
|
||||
protected abstract K generatePreKey(final long keyId);
|
||||
|
||||
protected abstract void clearKeyCountAttributes();
|
||||
|
||||
@Test
|
||||
void storeTake() {
|
||||
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
|
||||
@@ -32,20 +40,22 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
|
||||
|
||||
assertEquals(Optional.empty(), preKeyStore.take(accountIdentifier, deviceId).join());
|
||||
|
||||
final List<K> preKeys = new ArrayList<>(KEY_COUNT);
|
||||
final List<K> sortedPreKeys;
|
||||
{
|
||||
final List<K> preKeys = generateRandomPreKeys();
|
||||
assertDoesNotThrow(() -> preKeyStore.store(accountIdentifier, deviceId, preKeys).join());
|
||||
|
||||
for (int i = 0; i < KEY_COUNT; i++) {
|
||||
preKeys.add(generatePreKey(i));
|
||||
sortedPreKeys = new ArrayList<>(preKeys);
|
||||
sortedPreKeys.sort(Comparator.comparing(preKey -> preKey.keyId()));
|
||||
}
|
||||
|
||||
assertDoesNotThrow(() -> preKeyStore.store(accountIdentifier, deviceId, preKeys).join());
|
||||
|
||||
assertEquals(Optional.of(preKeys.get(0)), preKeyStore.take(accountIdentifier, deviceId).join());
|
||||
assertEquals(Optional.of(preKeys.get(1)), preKeyStore.take(accountIdentifier, deviceId).join());
|
||||
assertEquals(Optional.of(sortedPreKeys.get(0)), preKeyStore.take(accountIdentifier, deviceId).join());
|
||||
assertEquals(Optional.of(sortedPreKeys.get(1)), preKeyStore.take(accountIdentifier, deviceId).join());
|
||||
}
|
||||
|
||||
@Test
|
||||
void getCount() {
|
||||
@ParameterizedTest
|
||||
@ValueSource(booleans = {true, false})
|
||||
void getCount(final boolean hasKeyCountAttribute) {
|
||||
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
@@ -53,15 +63,72 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
|
||||
|
||||
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
|
||||
|
||||
final List<K> preKeys = new ArrayList<>(KEY_COUNT);
|
||||
|
||||
for (int i = 0; i < KEY_COUNT; i++) {
|
||||
preKeys.add(generatePreKey(i));
|
||||
}
|
||||
final List<K> preKeys = generateRandomPreKeys();
|
||||
|
||||
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
|
||||
|
||||
if (!hasKeyCountAttribute) {
|
||||
clearKeyCountAttributes();
|
||||
}
|
||||
|
||||
assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, deviceId).join());
|
||||
|
||||
for (int i = 0; i < KEY_COUNT; i++) {
|
||||
preKeyStore.take(accountIdentifier, deviceId).join();
|
||||
assertEquals(KEY_COUNT - (i + 1), preKeyStore.getCount(accountIdentifier, deviceId).join());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void peekCount() {
|
||||
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = 1;
|
||||
|
||||
assertEquals(Optional.of(0), preKeyStore.peekCount(accountIdentifier, deviceId).join());
|
||||
|
||||
final List<K> preKeys = generateRandomPreKeys();
|
||||
|
||||
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
|
||||
|
||||
assertEquals(Optional.of(KEY_COUNT), preKeyStore.peekCount(accountIdentifier, deviceId).join());
|
||||
|
||||
for (int i = 0; i < KEY_COUNT; i++) {
|
||||
preKeyStore.take(accountIdentifier, deviceId).join();
|
||||
assertEquals(Optional.of(KEY_COUNT - (i + 1)), preKeyStore.peekCount(accountIdentifier, deviceId).join());
|
||||
}
|
||||
|
||||
preKeyStore.store(accountIdentifier, deviceId, List.of(generatePreKey(KEY_COUNT + 1))).join();
|
||||
clearKeyCountAttributes();
|
||||
|
||||
assertEquals(Optional.empty(), preKeyStore.peekCount(accountIdentifier, deviceId).join());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(booleans = {true, false})
|
||||
void scanCount(final boolean hasKeyCountAttribute) {
|
||||
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = 1;
|
||||
|
||||
assertEquals(0, preKeyStore.scanCount(accountIdentifier, deviceId).join());
|
||||
|
||||
final List<K> preKeys = generateRandomPreKeys();
|
||||
|
||||
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
|
||||
|
||||
if (!hasKeyCountAttribute) {
|
||||
clearKeyCountAttributes();
|
||||
}
|
||||
|
||||
assertEquals(KEY_COUNT, preKeyStore.scanCount(accountIdentifier, deviceId).join());
|
||||
|
||||
for (int i = 0; i < KEY_COUNT; i++) {
|
||||
preKeyStore.take(accountIdentifier, deviceId).join();
|
||||
assertEquals(KEY_COUNT - (i + 1), preKeyStore.scanCount(accountIdentifier, deviceId).join());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -74,11 +141,7 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
|
||||
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
|
||||
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier, deviceId).join());
|
||||
|
||||
final List<K> preKeys = new ArrayList<>(KEY_COUNT);
|
||||
|
||||
for (int i = 0; i < KEY_COUNT; i++) {
|
||||
preKeys.add(generatePreKey(i));
|
||||
}
|
||||
final List<K> preKeys = generateRandomPreKeys();
|
||||
|
||||
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
|
||||
preKeyStore.store(accountIdentifier, (byte) (deviceId + 1), preKeys).join();
|
||||
@@ -99,11 +162,7 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
|
||||
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
|
||||
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier).join());
|
||||
|
||||
final List<K> preKeys = new ArrayList<>(KEY_COUNT);
|
||||
|
||||
for (int i = 0; i < KEY_COUNT; i++) {
|
||||
preKeys.add(generatePreKey(i));
|
||||
}
|
||||
final List<K> preKeys = generateRandomPreKeys();
|
||||
|
||||
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
|
||||
preKeyStore.store(accountIdentifier, (byte) (deviceId + 1), preKeys).join();
|
||||
@@ -113,4 +172,16 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
|
||||
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
|
||||
assertEquals(0, preKeyStore.getCount(accountIdentifier, (byte) (deviceId + 1)).join());
|
||||
}
|
||||
|
||||
private List<K> generateRandomPreKeys() {
|
||||
final Set<Integer> keyIds = new HashSet<>(KEY_COUNT);
|
||||
|
||||
while (keyIds.size() < KEY_COUNT) {
|
||||
keyIds.add(Math.abs(ThreadLocalRandom.current().nextInt()));
|
||||
}
|
||||
|
||||
return keyIds.stream()
|
||||
.map(this::generatePreKey)
|
||||
.toList();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user