Validate pre-key key-id ranges

This commit is contained in:
ravi-signal
2026-03-12 16:37:28 -05:00
committed by GitHub
parent ac23b8e79e
commit b7d455ed11
16 changed files with 179 additions and 36 deletions

View File

@@ -8,7 +8,10 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.Min;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.storage.KeyIdUtil;
import org.whispersystems.textsecuregcm.util.ECPublicKeyAdapter;
public record ECPreKey(
@@ -16,6 +19,8 @@ public record ECPreKey(
An arbitrary ID for this key, which will be provided by peers using this key to encrypt messages so the private key can be looked up.
Should not be zero. Should be less than 2^24.
""")
@Max(KeyIdUtil.MAX_KEY_ID)
@Min(KeyIdUtil.MIN_KEY_ID)
long keyId,
@JsonSerialize(using = ECPublicKeyAdapter.Serializer.class)

View File

@@ -8,7 +8,10 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.Min;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.storage.KeyIdUtil;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.ECPublicKeyAdapter;
import java.util.Arrays;
@@ -19,6 +22,8 @@ public record ECSignedPreKey(
An arbitrary ID for this key, which will be provided by peers using this key to encrypt messages so the private key can be looked up.
Should not be zero. Should be less than 2^24.
""")
@Max(KeyIdUtil.MAX_KEY_ID)
@Min(KeyIdUtil.MIN_KEY_ID)
long keyId,
@JsonSerialize(using = ECPublicKeyAdapter.Serializer.class)

View File

@@ -8,7 +8,10 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.Min;
import org.signal.libsignal.protocol.kem.KEMPublicKey;
import org.whispersystems.textsecuregcm.storage.KeyIdUtil;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.KEMPublicKeyAdapter;
import java.util.Arrays;
@@ -20,6 +23,8 @@ public record KEMSignedPreKey(
Should not be zero. Should be less than 2^24. The owner of this key must be able to determine from the key ID whether this represents
a single-use or last-resort key, but another party should *not* be able to tell.
""")
@Max(KeyIdUtil.MAX_KEY_ID)
@Min(KeyIdUtil.MIN_KEY_ID)
long keyId,
@JsonSerialize(using = KEMPublicKeyAdapter.Serializer.class)

View File

@@ -19,6 +19,7 @@ import org.signal.chat.keys.DevicePreKeyBundle;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeyIdUtil;
import org.whispersystems.textsecuregcm.storage.KeysManager;
class KeysGrpcHelper {
@@ -67,19 +68,20 @@ class KeysGrpcHelper {
preKeysByDeviceId.forEach((deviceId, devicePreKeys) -> {
final Device device = targetAccount.getDevice(deviceId).orElseThrow();
final DevicePreKeyBundle.Builder builder = DevicePreKeyBundle.newBuilder()
.setEcSignedPreKey(EcSignedPreKey.newBuilder()
.setKeyId(devicePreKeys.ecSignedPreKey().keyId())
.setKeyId(KeyIdUtil.toUnsignedInt(devicePreKeys.ecSignedPreKey().keyId()))
.setPublicKey(ByteString.copyFrom(devicePreKeys.ecSignedPreKey().serializedPublicKey()))
.setSignature(ByteString.copyFrom(devicePreKeys.ecSignedPreKey().signature())))
.setKemOneTimePreKey(KemSignedPreKey.newBuilder()
.setKeyId(devicePreKeys.kemSignedPreKey().keyId())
.setKeyId(KeyIdUtil.toUnsignedInt(devicePreKeys.kemSignedPreKey().keyId()))
.setPublicKey(ByteString.copyFrom(devicePreKeys.kemSignedPreKey().serializedPublicKey()))
.setSignature(ByteString.copyFrom(devicePreKeys.kemSignedPreKey().signature())))
.setRegistrationId(device.getRegistrationId(targetServiceIdentifier.identityType()));
devicePreKeys.ecPreKey().ifPresent(ecPreKey -> builder.setEcOneTimePreKey(EcPreKey.newBuilder()
.setKeyId(ecPreKey.keyId())
.setKeyId(KeyIdUtil.toUnsignedInt(ecPreKey.keyId()))
.setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey()))));
preKeyBundlesBuilder.putDevicePreKeys(deviceId, builder.build());

View File

@@ -0,0 +1,29 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
public class KeyIdUtil {
public static final long MAX_KEY_ID = (1L << 32) - 1;
public static final long MIN_KEY_ID = 0;
private KeyIdUtil(){}
public static boolean keyIdValid(final long keyId) {
return keyId <= MAX_KEY_ID && keyId >= MIN_KEY_ID;
}
/// Convert a long keyId (a 32-bit unsigned int) into an int representation.
///
/// The inverse of [Integer#toUnsignedLong].
///
/// @param keyId A key ID which must be in the range [0, 2^32)
/// @throws IllegalArgumentException If `keyId` is not within the range
/// @return A 32-bit unsigned integer where the top bit is stored in the sign bit
public static int toUnsignedInt(final long keyId) {
if (!keyIdValid(keyId)) {
throw new IllegalArgumentException("Invalid keyId " + keyId);
}
return (int) keyId;
}
}

View File

@@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.storage;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
@@ -71,6 +72,7 @@ public class PagedSingleUseKEMPreKeyStore {
private final Timer deleteForDeviceTimer = Metrics.timer(name(getClass(), "deleteForDevice"));
private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount"));
private final Counter outOfRangeKeysDiscarded = Metrics.counter(name(getClass(), "outOfRangeKeysDiscarded"));
final DistributionSummary availableKeyCountDistributionSummary = DistributionSummary
.builder(name(getClass(), "availableKeyCount"))
.register(Metrics.globalRegistry);
@@ -159,7 +161,14 @@ public class PagedSingleUseKEMPreKeyStore {
*/
public CompletableFuture<Optional<KEMSignedPreKey>> take(final UUID identifier, final byte deviceId) {
final Timer.Sample sample = Timer.start();
return takeHelper(identifier, deviceId)
.whenComplete((maybeKey, throwable) ->
sample.stop(Metrics.timer(
takeKeyTimerName,
KEY_PRESENT_TAG_NAME, String.valueOf(maybeKey != null && maybeKey.isPresent()))));
}
private CompletableFuture<Optional<KEMSignedPreKey>> takeHelper(final UUID identifier, final byte deviceId) {
return dynamoDbAsyncClient.updateItem(UpdateItemRequest.builder()
.tableName(tableName)
.key(Map.of(
@@ -196,10 +205,15 @@ public class PagedSingleUseKEMPreKeyStore {
.exceptionally(ExceptionUtils.exceptionallyHandler(
ConditionalCheckFailedException.class,
e -> Optional.empty()))
.whenComplete((maybeKey, throwable) ->
sample.stop(Metrics.timer(
takeKeyTimerName,
KEY_PRESENT_TAG_NAME, String.valueOf(maybeKey != null && maybeKey.isPresent()))));
.thenCompose(maybeKey -> {
if (!maybeKey.map(KEMSignedPreKey::keyId).map(KeyIdUtil::keyIdValid).orElse(true)) {
// At some point we did not validate that keyIds fit in an unsigned 32-bit integer, which clients require.
// This keyId was invalid, so just recursively fetch the next key
outOfRangeKeysDiscarded.increment();
return takeHelper(identifier, deviceId);
}
return CompletableFuture.completedFuture(maybeKey);
});
}
/**

View File

@@ -107,11 +107,14 @@ public abstract class RepeatedUseSignedPreKeyStore<K extends SignedPreKey<?>> {
.build())
.thenApply(response -> response.hasItem() ? Optional.of(getPreKeyFromItem(response.item())) : Optional.empty());
findFuture.whenComplete((maybeSignedPreKey, throwable) ->
sample.stop(Metrics.timer(findKeyTimerName,
"keyPresent", String.valueOf(maybeSignedPreKey != null && maybeSignedPreKey.isPresent()))));
return findFuture.whenComplete((maybeSignedPreKey, throwable) -> {
if (throwable == null && maybeSignedPreKey.map(k -> !KeyIdUtil.keyIdValid(k.keyId())).orElse(false)) {
throw new IllegalStateException("Encountered an impossible invalid repeated use pre-key id of " + maybeSignedPreKey.get().keyId());
}
return findFuture;
sample.stop(Metrics.timer(findKeyTimerName,
"keyPresent", String.valueOf(maybeSignedPreKey != null && maybeSignedPreKey.isPresent())));
});
}
protected static Map<String, AttributeValue> getPrimaryKey(final UUID identifier, final byte deviceId) {

View File

@@ -47,7 +47,6 @@ import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
* may fall back to using the device's repeated-use ("last-resort") signed pre-key instead.
*/
public class SingleUseECPreKeyStore {
private final DynamoDbAsyncClient dynamoDbAsyncClient;
private final String tableName;
@@ -58,12 +57,14 @@ public class SingleUseECPreKeyStore {
private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount"));
private final Counter noKeyCountAvailableCounter = Metrics.counter(name(getClass(), "noKeyCountAvailable"));
private final Counter outOfRangeKeysDiscarded =
Metrics.counter(name(getClass(), "outOfRangeKeysDiscarded"));
final DistributionSummary keysConsideredForTakeDistributionSummary = DistributionSummary
.builder(name(getClass(), "keysConsideredForTake"))
.distributionStatisticExpiry(Duration.ofMinutes(10))
.register(Metrics.globalRegistry);
final DistributionSummary availableKeyCountDistributionSummary = DistributionSummary
.builder(name(getClass(), "availableKeyCount"))
.distributionStatisticExpiry(Duration.ofMinutes(10))
@@ -135,7 +136,7 @@ public class SingleUseECPreKeyStore {
public CompletableFuture<Optional<ECPreKey>> take(final UUID identifier, final byte deviceId) {
final Timer.Sample sample = Timer.start();
final AttributeValue partitionKey = getPartitionKey(identifier);
final AtomicInteger keysConsidered = new AtomicInteger(0);
final AtomicInteger deletionAttempts = new AtomicInteger(0);
return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
@@ -156,16 +157,26 @@ public class SingleUseECPreKeyStore {
KEY_DEVICE_ID_KEY_ID, item.get(KEY_DEVICE_ID_KEY_ID)))
.returnValues(ReturnValue.ALL_OLD)
.build())
.flatMap(deleteItemRequest -> Mono.fromFuture(() -> dynamoDbAsyncClient.deleteItem(deleteItemRequest)), 1)
.doOnNext(deleteItemResponse -> keysConsidered.incrementAndGet())
.concatMap(deleteItemRequest -> Mono.fromFuture(() -> dynamoDbAsyncClient.deleteItem(deleteItemRequest)))
.doOnNext(_ -> deletionAttempts.incrementAndGet())
.filter(DeleteItemResponse::hasAttributes)
.filter(item -> {
final long keyId = getKeyIdFromItem(item.attributes());
final boolean keyIdValid = KeyIdUtil.keyIdValid(keyId);
if (!keyIdValid) {
outOfRangeKeysDiscarded.increment();
}
// At some point we did not validate that keyIds fit in an unsigned 32-bit integer, which clients require.
// If this keyId is invalid, we'll skip it and fetch the next key
return keyIdValid;
})
.next()
.map(deleteItemResponse -> getPreKeyFromItem(deleteItemResponse.attributes()))
.toFuture()
.thenApply(Optional::ofNullable)
.whenComplete((maybeKey, throwable) -> {
sample.stop(Metrics.timer(takeKeyTimerName, KEY_PRESENT_TAG_NAME, String.valueOf(maybeKey != null && maybeKey.isPresent())));
keysConsideredForTakeDistributionSummary.record(keysConsidered.get());
keysConsideredForTakeDistributionSummary.record(deletionAttempts.get());
});
}
@@ -310,7 +321,7 @@ public class SingleUseECPreKeyStore {
}
private ECPreKey getPreKeyFromItem(final Map<String, AttributeValue> item) {
final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
final long keyId = getKeyIdFromItem(item);
final byte[] publicKey = AttributeValues.extractByteArray(item.get(ATTR_PUBLIC_KEY), PARSE_BYTE_ARRAY_COUNTER_NAME);
try {
@@ -320,4 +331,8 @@ public class SingleUseECPreKeyStore {
throw new IllegalArgumentException(e);
}
}
private static long getKeyIdFromItem(final Map<String, AttributeValue> item) {
return item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
}
}

View File

@@ -41,7 +41,7 @@ message EcPreKey {
// A locally-unique identifier for this key, which will be provided by
// peers using this key to encrypt messages so the private key can be looked
// up.
uint64 key_id = 1;
uint32 key_id = 1;
// The public key, serialized in libsignal's elliptic-curve public key format.
bytes public_key = 2 [(require.nonEmpty) = true];
@@ -51,7 +51,7 @@ message EcSignedPreKey {
// A locally-unique identifier for this key, which will be provided by
// peers using this key to encrypt messages so the private key can be looked
// up.
uint64 key_id = 1;
uint32 key_id = 1;
// The public key, serialized in libsignal's elliptic-curve public key format.
bytes public_key = 2 [(require.nonEmpty) = true];
@@ -64,7 +64,7 @@ message EcSignedPreKey {
message KemSignedPreKey {
// An locally-unique identifier for this key, which will be provided by peers
// using this key to encrypt messages so the private key can be looked up.
uint64 key_id = 1;
uint32 key_id = 1;
// The public key, serialized in libsignal's Kyber1024 public key format.
bytes public_key = 2 [(require.nonEmpty) = true];

View File

@@ -66,6 +66,7 @@ import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeyIdUtil;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
@@ -508,14 +509,14 @@ class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<KeysAnonymousGrpcS
private static EcPreKey toGrpcEcPreKey(final ECPreKey preKey) {
return EcPreKey.newBuilder()
.setKeyId(preKey.keyId())
.setKeyId(KeyIdUtil.toUnsignedInt(preKey.keyId()))
.setPublicKey(ByteString.copyFrom(preKey.publicKey().serialize()))
.build();
}
private static EcSignedPreKey toGrpcEcSignedPreKey(final ECSignedPreKey preKey) {
return EcSignedPreKey.newBuilder()
.setKeyId(preKey.keyId())
.setKeyId(KeyIdUtil.toUnsignedInt(preKey.keyId()))
.setPublicKey(ByteString.copyFrom(preKey.publicKey().serialize()))
.setSignature(ByteString.copyFrom(preKey.signature()))
.build();
@@ -523,7 +524,7 @@ class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<KeysAnonymousGrpcS
private static KemSignedPreKey toGrpcKemSignedPreKey(final KEMSignedPreKey preKey) {
return KemSignedPreKey.newBuilder()
.setKeyId(preKey.keyId())
.setKeyId(KeyIdUtil.toUnsignedInt(preKey.keyId()))
.setPublicKey(ByteString.copyFrom(preKey.publicKey().serialize()))
.setSignature(ByteString.copyFrom(preKey.signature()))
.build();

View File

@@ -68,6 +68,7 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeyIdUtil;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
@@ -161,7 +162,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
.setIdentityType(identityType)
.addAllPreKeys(preKeys.stream()
.map(preKey -> EcPreKey.newBuilder()
.setKeyId(preKey.keyId())
.setKeyId(KeyIdUtil.toUnsignedInt(preKey.keyId()))
.setPublicKey(ByteString.copyFrom(preKey.serializedPublicKey()))
.build())
.toList())
@@ -233,7 +234,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
.setIdentityType(identityType)
.addAllPreKeys(preKeys.stream()
.map(preKey -> KemSignedPreKey.newBuilder()
.setKeyId(preKey.keyId())
.setKeyId(KeyIdUtil.toUnsignedInt(preKey.keyId()))
.setPublicKey(ByteString.copyFrom(preKey.serializedPublicKey()))
.setSignature(ByteString.copyFrom(preKey.signature()))
.build())
@@ -309,7 +310,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
authenticatedServiceStub().setEcSignedPreKey(SetEcSignedPreKeyRequest.newBuilder()
.setIdentityType(identityType)
.setSignedPreKey(EcSignedPreKey.newBuilder()
.setKeyId(signedPreKey.keyId())
.setKeyId(KeyIdUtil.toUnsignedInt(signedPreKey.keyId()))
.setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey()))
.setSignature(ByteString.copyFrom(signedPreKey.signature()))
.build())
@@ -339,7 +340,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
final SetEcSignedPreKeyRequest prototypeRequest = SetEcSignedPreKeyRequest.newBuilder()
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setSignedPreKey(EcSignedPreKey.newBuilder()
.setKeyId(signedPreKey.keyId())
.setKeyId(KeyIdUtil.toUnsignedInt(signedPreKey.keyId()))
.setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey()))
.setSignature(ByteString.copyFrom(signedPreKey.signature()))
.build())
@@ -388,7 +389,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
authenticatedServiceStub().setKemLastResortPreKey(SetKemLastResortPreKeyRequest.newBuilder()
.setIdentityType(identityType)
.setSignedPreKey(KemSignedPreKey.newBuilder()
.setKeyId(lastResortPreKey.keyId())
.setKeyId(KeyIdUtil.toUnsignedInt(lastResortPreKey.keyId()))
.setPublicKey(ByteString.copyFrom(lastResortPreKey.serializedPublicKey()))
.setSignature(ByteString.copyFrom(lastResortPreKey.signature()))
.build())
@@ -415,7 +416,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
final SetKemLastResortPreKeyRequest prototypeRequest = SetKemLastResortPreKeyRequest.newBuilder()
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setSignedPreKey(KemSignedPreKey.newBuilder()
.setKeyId(lastResortPreKey.keyId())
.setKeyId(KeyIdUtil.toUnsignedInt(lastResortPreKey.keyId()))
.setPublicKey(ByteString.copyFrom(lastResortPreKey.serializedPublicKey()))
.setSignature(ByteString.copyFrom(lastResortPreKey.signature()))
.build())
@@ -493,19 +494,19 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
final DevicePreKeyBundle.Builder builder = DevicePreKeyBundle.newBuilder()
.setEcSignedPreKey(EcSignedPreKey.newBuilder()
.setKeyId(ecSignedPreKey.keyId())
.setKeyId(KeyIdUtil.toUnsignedInt(ecSignedPreKey.keyId()))
.setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey()))
.setSignature(ByteString.copyFrom(ecSignedPreKey.signature()))
.build())
.setKemOneTimePreKey(KemSignedPreKey.newBuilder()
.setKeyId(kemSignedPreKey.keyId())
.setKeyId(KeyIdUtil.toUnsignedInt(kemSignedPreKey.keyId()))
.setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey()))
.setSignature(ByteString.copyFrom(kemSignedPreKey.signature()))
.build())
.setRegistrationId(entry.getValue());
maybeEcPreKey.ifPresent(ecPreKey -> builder
.setEcOneTimePreKey(EcPreKey.newBuilder()
.setKeyId(ecPreKey.keyId())
.setKeyId(KeyIdUtil.toUnsignedInt(ecPreKey.keyId()))
.setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey()))
.build()));
expectedPreKeyBundles.put(entry.getKey(), builder.build());

View File

@@ -253,6 +253,22 @@ class PagedSingleUseKEMPreKeyStoreTest {
.block();
}
@Test
void takeSkipsOutOfRangeKeys() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = 1;
final KEMSignedPreKey validKey = KeysHelper.signedKEMPreKey(1, IDENTITY_KEY_PAIR);
final KEMSignedPreKey outOfRange1 = KeysHelper.signedKEMPreKey(KeyIdUtil.MAX_KEY_ID + 1, IDENTITY_KEY_PAIR);
final KEMSignedPreKey outOfRange2 = KeysHelper.signedKEMPreKey(KeyIdUtil.MAX_KEY_ID + 2, IDENTITY_KEY_PAIR);
keyStore.store(accountIdentifier, deviceId, List.of(outOfRange1, outOfRange2, validKey)).join();
assertEquals(Optional.of(validKey), keyStore.take(accountIdentifier, deviceId).join());
assertEquals(Optional.empty(), keyStore.take(accountIdentifier, deviceId).join());
assertEquals(0, keyStore.getCount(accountIdentifier, deviceId).join());
}
private List<KEMSignedPreKey> generateRandomPreKeys() {
final Set<Integer> keyIds = new HashSet<>(KEY_COUNT);

View File

@@ -37,11 +37,17 @@ class RepeatedUseECSignedPreKeyStoreTest extends RepeatedUseSignedPreKeyStoreTes
@Override
protected ECSignedPreKey generateSignedPreKey() {
return KeysHelper.signedECPreKey(currentKeyId++, IDENTITY_KEY_PAIR);
return generateSignedPreKey(currentKeyId++);
}
@Override
protected ECSignedPreKey generateSignedPreKey(long keyId) {
return KeysHelper.signedECPreKey(keyId, IDENTITY_KEY_PAIR);
}
@Override
protected DynamoDbClient getDynamoDbClient() {
return DYNAMO_DB_EXTENSION.getDynamoDbClient();
}
}

View File

@@ -42,6 +42,11 @@ class RepeatedUseKEMSignedPreKeyStoreTest extends RepeatedUseSignedPreKeyStoreTe
@Override
protected KEMSignedPreKey generateSignedPreKey() {
return KeysHelper.signedKEMPreKey(currentKeyId++, IDENTITY_KEY_PAIR);
return generateSignedPreKey(currentKeyId++);
}
@Override
protected KEMSignedPreKey generateSignedPreKey(long keyId) {
return KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR);
}
}

View File

@@ -8,11 +8,11 @@ package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest;
@@ -22,6 +22,8 @@ abstract class RepeatedUseSignedPreKeyStoreTest<K extends SignedPreKey<?>> {
protected abstract K generateSignedPreKey();
protected abstract K generateSignedPreKey(long keyId);
protected abstract DynamoDbClient getDynamoDbClient();
@Test
@@ -72,4 +74,16 @@ abstract class RepeatedUseSignedPreKeyStoreTest<K extends SignedPreKey<?>> {
assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join());
assertEquals(Optional.of(retainedPreKey), keys.find(identifier, deviceId2).join());
}
@Test
void findThrowsOnOutOfRangeKeyId() {
final RepeatedUseSignedPreKeyStore<K> keys = getKeyStore();
final UUID identifier = UUID.randomUUID();
final byte deviceId = 1;
final K outOfRangeKey = generateSignedPreKey(KeyIdUtil.MAX_KEY_ID + 1);
keys.store(identifier, deviceId, outOfRangeKey).join();
CompletableFutureTestUtil.assertFailsWithCause(IllegalStateException.class, keys.find(identifier, deviceId));
}
}

View File

@@ -155,6 +155,28 @@ class SingleUseECPreKeyStoreTest {
assertEquals(0, preKeyStore.getCount(accountIdentifier, (byte) (deviceId + 1)).join());
}
@Test
void takeSkipsOutOfRangeKeys() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = 1;
final long outOfRange1 = KeyIdUtil.MAX_KEY_ID + 1;
final long outOfRange2 = KeyIdUtil.MAX_KEY_ID + 2;
final long validKeyId = 1;
final List<ECPreKey> preKeys = List.of(
generatePreKey(outOfRange1),
generatePreKey(outOfRange2),
generatePreKey(validKeyId));
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
final Optional<ECPreKey> taken = preKeyStore.take(accountIdentifier, deviceId).join();
assertEquals(Optional.of(preKeys.get(2)), taken);
assertEquals(Optional.empty(), preKeyStore.take(accountIdentifier, deviceId).join());
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
}
private List<ECPreKey> generateRandomPreKeys() {
final Set<Integer> keyIds = new HashSet<>(KEY_COUNT);