diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ECPreKey.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ECPreKey.java index e638183d6..a52f5c64c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ECPreKey.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ECPreKey.java @@ -8,7 +8,10 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.Max; +import jakarta.validation.constraints.Min; import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.storage.KeyIdUtil; import org.whispersystems.textsecuregcm.util.ECPublicKeyAdapter; public record ECPreKey( @@ -16,6 +19,8 @@ public record ECPreKey( An arbitrary ID for this key, which will be provided by peers using this key to encrypt messages so the private key can be looked up. Should not be zero. Should be less than 2^24. """) + @Max(KeyIdUtil.MAX_KEY_ID) + @Min(KeyIdUtil.MIN_KEY_ID) long keyId, @JsonSerialize(using = ECPublicKeyAdapter.Serializer.class) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ECSignedPreKey.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ECSignedPreKey.java index b585286bc..1ec4f9b49 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ECSignedPreKey.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ECSignedPreKey.java @@ -8,7 +8,10 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.Max; +import jakarta.validation.constraints.Min; import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.storage.KeyIdUtil; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.ECPublicKeyAdapter; import java.util.Arrays; @@ -19,6 +22,8 @@ public record ECSignedPreKey( An arbitrary ID for this key, which will be provided by peers using this key to encrypt messages so the private key can be looked up. Should not be zero. Should be less than 2^24. """) + @Max(KeyIdUtil.MAX_KEY_ID) + @Min(KeyIdUtil.MIN_KEY_ID) long keyId, @JsonSerialize(using = ECPublicKeyAdapter.Serializer.class) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/KEMSignedPreKey.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/KEMSignedPreKey.java index 5e2ff6eca..cb10e0523 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/KEMSignedPreKey.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/KEMSignedPreKey.java @@ -8,7 +8,10 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.Max; +import jakarta.validation.constraints.Min; import org.signal.libsignal.protocol.kem.KEMPublicKey; +import org.whispersystems.textsecuregcm.storage.KeyIdUtil; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.KEMPublicKeyAdapter; import java.util.Arrays; @@ -20,6 +23,8 @@ public record KEMSignedPreKey( Should not be zero. Should be less than 2^24. The owner of this key must be able to determine from the key ID whether this represents a single-use or last-resort key, but another party should *not* be able to tell. """) + @Max(KeyIdUtil.MAX_KEY_ID) + @Min(KeyIdUtil.MIN_KEY_ID) long keyId, @JsonSerialize(using = KEMPublicKeyAdapter.Serializer.class) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java index 4fed1e190..9d1793c1f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java @@ -19,6 +19,7 @@ import org.signal.chat.keys.DevicePreKeyBundle; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.KeyIdUtil; import org.whispersystems.textsecuregcm.storage.KeysManager; class KeysGrpcHelper { @@ -67,19 +68,20 @@ class KeysGrpcHelper { preKeysByDeviceId.forEach((deviceId, devicePreKeys) -> { final Device device = targetAccount.getDevice(deviceId).orElseThrow(); + final DevicePreKeyBundle.Builder builder = DevicePreKeyBundle.newBuilder() .setEcSignedPreKey(EcSignedPreKey.newBuilder() - .setKeyId(devicePreKeys.ecSignedPreKey().keyId()) + .setKeyId(KeyIdUtil.toUnsignedInt(devicePreKeys.ecSignedPreKey().keyId())) .setPublicKey(ByteString.copyFrom(devicePreKeys.ecSignedPreKey().serializedPublicKey())) .setSignature(ByteString.copyFrom(devicePreKeys.ecSignedPreKey().signature()))) .setKemOneTimePreKey(KemSignedPreKey.newBuilder() - .setKeyId(devicePreKeys.kemSignedPreKey().keyId()) + .setKeyId(KeyIdUtil.toUnsignedInt(devicePreKeys.kemSignedPreKey().keyId())) .setPublicKey(ByteString.copyFrom(devicePreKeys.kemSignedPreKey().serializedPublicKey())) .setSignature(ByteString.copyFrom(devicePreKeys.kemSignedPreKey().signature()))) .setRegistrationId(device.getRegistrationId(targetServiceIdentifier.identityType())); devicePreKeys.ecPreKey().ifPresent(ecPreKey -> builder.setEcOneTimePreKey(EcPreKey.newBuilder() - .setKeyId(ecPreKey.keyId()) + .setKeyId(KeyIdUtil.toUnsignedInt(ecPreKey.keyId())) .setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey())))); preKeyBundlesBuilder.putDevicePreKeys(deviceId, builder.build()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeyIdUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeyIdUtil.java new file mode 100644 index 000000000..d83ed5767 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeyIdUtil.java @@ -0,0 +1,29 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.storage; + +public class KeyIdUtil { + public static final long MAX_KEY_ID = (1L << 32) - 1; + public static final long MIN_KEY_ID = 0; + private KeyIdUtil(){} + + public static boolean keyIdValid(final long keyId) { + return keyId <= MAX_KEY_ID && keyId >= MIN_KEY_ID; + } + + /// Convert a long keyId (a 32-bit unsigned int) into an int representation. + /// + /// The inverse of [Integer#toUnsignedLong]. + /// + /// @param keyId A key ID which must be in the range [0, 2^32) + /// @throws IllegalArgumentException If `keyId` is not within the range + /// @return A 32-bit unsigned integer where the top bit is stored in the sign bit + public static int toUnsignedInt(final long keyId) { + if (!keyIdValid(keyId)) { + throw new IllegalArgumentException("Invalid keyId " + keyId); + } + return (int) keyId; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStore.java index 7e896dd0d..da11aeedb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStore.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.storage; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; +import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; @@ -71,6 +72,7 @@ public class PagedSingleUseKEMPreKeyStore { private final Timer deleteForDeviceTimer = Metrics.timer(name(getClass(), "deleteForDevice")); private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount")); + private final Counter outOfRangeKeysDiscarded = Metrics.counter(name(getClass(), "outOfRangeKeysDiscarded")); final DistributionSummary availableKeyCountDistributionSummary = DistributionSummary .builder(name(getClass(), "availableKeyCount")) .register(Metrics.globalRegistry); @@ -159,7 +161,14 @@ public class PagedSingleUseKEMPreKeyStore { */ public CompletableFuture> take(final UUID identifier, final byte deviceId) { final Timer.Sample sample = Timer.start(); + return takeHelper(identifier, deviceId) + .whenComplete((maybeKey, throwable) -> + sample.stop(Metrics.timer( + takeKeyTimerName, + KEY_PRESENT_TAG_NAME, String.valueOf(maybeKey != null && maybeKey.isPresent())))); + } + private CompletableFuture> takeHelper(final UUID identifier, final byte deviceId) { return dynamoDbAsyncClient.updateItem(UpdateItemRequest.builder() .tableName(tableName) .key(Map.of( @@ -196,10 +205,15 @@ public class PagedSingleUseKEMPreKeyStore { .exceptionally(ExceptionUtils.exceptionallyHandler( ConditionalCheckFailedException.class, e -> Optional.empty())) - .whenComplete((maybeKey, throwable) -> - sample.stop(Metrics.timer( - takeKeyTimerName, - KEY_PRESENT_TAG_NAME, String.valueOf(maybeKey != null && maybeKey.isPresent())))); + .thenCompose(maybeKey -> { + if (!maybeKey.map(KEMSignedPreKey::keyId).map(KeyIdUtil::keyIdValid).orElse(true)) { + // At some point we did not validate that keyIds fit in an unsigned 32-bit integer, which clients require. + // This keyId was invalid, so just recursively fetch the next key + outOfRangeKeysDiscarded.increment(); + return takeHelper(identifier, deviceId); + } + return CompletableFuture.completedFuture(maybeKey); + }); } /** diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java index c51133ead..6f2a64d39 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java @@ -107,11 +107,14 @@ public abstract class RepeatedUseSignedPreKeyStore> { .build()) .thenApply(response -> response.hasItem() ? Optional.of(getPreKeyFromItem(response.item())) : Optional.empty()); - findFuture.whenComplete((maybeSignedPreKey, throwable) -> - sample.stop(Metrics.timer(findKeyTimerName, - "keyPresent", String.valueOf(maybeSignedPreKey != null && maybeSignedPreKey.isPresent())))); + return findFuture.whenComplete((maybeSignedPreKey, throwable) -> { + if (throwable == null && maybeSignedPreKey.map(k -> !KeyIdUtil.keyIdValid(k.keyId())).orElse(false)) { + throw new IllegalStateException("Encountered an impossible invalid repeated use pre-key id of " + maybeSignedPreKey.get().keyId()); + } - return findFuture; + sample.stop(Metrics.timer(findKeyTimerName, + "keyPresent", String.valueOf(maybeSignedPreKey != null && maybeSignedPreKey.isPresent()))); + }); } protected static Map getPrimaryKey(final UUID identifier, final byte deviceId) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java index 36f97a185..a3f1e24ab 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java @@ -47,7 +47,6 @@ import software.amazon.awssdk.services.dynamodb.model.ReturnValue; * may fall back to using the device's repeated-use ("last-resort") signed pre-key instead. */ public class SingleUseECPreKeyStore { - private final DynamoDbAsyncClient dynamoDbAsyncClient; private final String tableName; @@ -58,12 +57,14 @@ public class SingleUseECPreKeyStore { private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount")); private final Counter noKeyCountAvailableCounter = Metrics.counter(name(getClass(), "noKeyCountAvailable")); - + private final Counter outOfRangeKeysDiscarded = + Metrics.counter(name(getClass(), "outOfRangeKeysDiscarded")); final DistributionSummary keysConsideredForTakeDistributionSummary = DistributionSummary .builder(name(getClass(), "keysConsideredForTake")) .distributionStatisticExpiry(Duration.ofMinutes(10)) .register(Metrics.globalRegistry); + final DistributionSummary availableKeyCountDistributionSummary = DistributionSummary .builder(name(getClass(), "availableKeyCount")) .distributionStatisticExpiry(Duration.ofMinutes(10)) @@ -135,7 +136,7 @@ public class SingleUseECPreKeyStore { public CompletableFuture> take(final UUID identifier, final byte deviceId) { final Timer.Sample sample = Timer.start(); final AttributeValue partitionKey = getPartitionKey(identifier); - final AtomicInteger keysConsidered = new AtomicInteger(0); + final AtomicInteger deletionAttempts = new AtomicInteger(0); return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() .tableName(tableName) @@ -156,16 +157,26 @@ public class SingleUseECPreKeyStore { KEY_DEVICE_ID_KEY_ID, item.get(KEY_DEVICE_ID_KEY_ID))) .returnValues(ReturnValue.ALL_OLD) .build()) - .flatMap(deleteItemRequest -> Mono.fromFuture(() -> dynamoDbAsyncClient.deleteItem(deleteItemRequest)), 1) - .doOnNext(deleteItemResponse -> keysConsidered.incrementAndGet()) + .concatMap(deleteItemRequest -> Mono.fromFuture(() -> dynamoDbAsyncClient.deleteItem(deleteItemRequest))) + .doOnNext(_ -> deletionAttempts.incrementAndGet()) .filter(DeleteItemResponse::hasAttributes) + .filter(item -> { + final long keyId = getKeyIdFromItem(item.attributes()); + final boolean keyIdValid = KeyIdUtil.keyIdValid(keyId); + if (!keyIdValid) { + outOfRangeKeysDiscarded.increment(); + } + // At some point we did not validate that keyIds fit in an unsigned 32-bit integer, which clients require. + // If this keyId is invalid, we'll skip it and fetch the next key + return keyIdValid; + }) .next() .map(deleteItemResponse -> getPreKeyFromItem(deleteItemResponse.attributes())) .toFuture() .thenApply(Optional::ofNullable) .whenComplete((maybeKey, throwable) -> { sample.stop(Metrics.timer(takeKeyTimerName, KEY_PRESENT_TAG_NAME, String.valueOf(maybeKey != null && maybeKey.isPresent()))); - keysConsideredForTakeDistributionSummary.record(keysConsidered.get()); + keysConsideredForTakeDistributionSummary.record(deletionAttempts.get()); }); } @@ -310,7 +321,7 @@ public class SingleUseECPreKeyStore { } private ECPreKey getPreKeyFromItem(final Map item) { - final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8); + final long keyId = getKeyIdFromItem(item); final byte[] publicKey = AttributeValues.extractByteArray(item.get(ATTR_PUBLIC_KEY), PARSE_BYTE_ARRAY_COUNTER_NAME); try { @@ -320,4 +331,8 @@ public class SingleUseECPreKeyStore { throw new IllegalArgumentException(e); } } + + private static long getKeyIdFromItem(final Map item) { + return item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8); + } } diff --git a/service/src/main/proto/org/signal/chat/common.proto b/service/src/main/proto/org/signal/chat/common.proto index d0ba1a1d6..0b924f430 100644 --- a/service/src/main/proto/org/signal/chat/common.proto +++ b/service/src/main/proto/org/signal/chat/common.proto @@ -41,7 +41,7 @@ message EcPreKey { // A locally-unique identifier for this key, which will be provided by // peers using this key to encrypt messages so the private key can be looked // up. - uint64 key_id = 1; + uint32 key_id = 1; // The public key, serialized in libsignal's elliptic-curve public key format. bytes public_key = 2 [(require.nonEmpty) = true]; @@ -51,7 +51,7 @@ message EcSignedPreKey { // A locally-unique identifier for this key, which will be provided by // peers using this key to encrypt messages so the private key can be looked // up. - uint64 key_id = 1; + uint32 key_id = 1; // The public key, serialized in libsignal's elliptic-curve public key format. bytes public_key = 2 [(require.nonEmpty) = true]; @@ -64,7 +64,7 @@ message EcSignedPreKey { message KemSignedPreKey { // An locally-unique identifier for this key, which will be provided by peers // using this key to encrypt messages so the private key can be looked up. - uint64 key_id = 1; + uint32 key_id = 1; // The public key, serialized in libsignal's Kyber1024 public key format. bytes public_key = 2 [(require.nonEmpty) = true]; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java index 272efcdb5..530a8b820 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java @@ -66,6 +66,7 @@ import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.KeyIdUtil; import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; @@ -508,14 +509,14 @@ class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest EcPreKey.newBuilder() - .setKeyId(preKey.keyId()) + .setKeyId(KeyIdUtil.toUnsignedInt(preKey.keyId())) .setPublicKey(ByteString.copyFrom(preKey.serializedPublicKey())) .build()) .toList()) @@ -233,7 +234,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest KemSignedPreKey.newBuilder() - .setKeyId(preKey.keyId()) + .setKeyId(KeyIdUtil.toUnsignedInt(preKey.keyId())) .setPublicKey(ByteString.copyFrom(preKey.serializedPublicKey())) .setSignature(ByteString.copyFrom(preKey.signature())) .build()) @@ -309,7 +310,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest builder .setEcOneTimePreKey(EcPreKey.newBuilder() - .setKeyId(ecPreKey.keyId()) + .setKeyId(KeyIdUtil.toUnsignedInt(ecPreKey.keyId())) .setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey())) .build())); expectedPreKeyBundles.put(entry.getKey(), builder.build()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStoreTest.java index b1a76bd4d..245ae800a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStoreTest.java @@ -253,6 +253,22 @@ class PagedSingleUseKEMPreKeyStoreTest { .block(); } + @Test + void takeSkipsOutOfRangeKeys() { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = 1; + + final KEMSignedPreKey validKey = KeysHelper.signedKEMPreKey(1, IDENTITY_KEY_PAIR); + final KEMSignedPreKey outOfRange1 = KeysHelper.signedKEMPreKey(KeyIdUtil.MAX_KEY_ID + 1, IDENTITY_KEY_PAIR); + final KEMSignedPreKey outOfRange2 = KeysHelper.signedKEMPreKey(KeyIdUtil.MAX_KEY_ID + 2, IDENTITY_KEY_PAIR); + + keyStore.store(accountIdentifier, deviceId, List.of(outOfRange1, outOfRange2, validKey)).join(); + assertEquals(Optional.of(validKey), keyStore.take(accountIdentifier, deviceId).join()); + assertEquals(Optional.empty(), keyStore.take(accountIdentifier, deviceId).join()); + assertEquals(0, keyStore.getCount(accountIdentifier, deviceId).join()); + + } + private List generateRandomPreKeys() { final Set keyIds = new HashSet<>(KEY_COUNT); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java index 635a8e585..007fe45ab 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java @@ -37,11 +37,17 @@ class RepeatedUseECSignedPreKeyStoreTest extends RepeatedUseSignedPreKeyStoreTes @Override protected ECSignedPreKey generateSignedPreKey() { - return KeysHelper.signedECPreKey(currentKeyId++, IDENTITY_KEY_PAIR); + return generateSignedPreKey(currentKeyId++); + } + + @Override + protected ECSignedPreKey generateSignedPreKey(long keyId) { + return KeysHelper.signedECPreKey(keyId, IDENTITY_KEY_PAIR); } @Override protected DynamoDbClient getDynamoDbClient() { return DYNAMO_DB_EXTENSION.getDynamoDbClient(); } + } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStoreTest.java index bc385833d..b991cc4a0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStoreTest.java @@ -42,6 +42,11 @@ class RepeatedUseKEMSignedPreKeyStoreTest extends RepeatedUseSignedPreKeyStoreTe @Override protected KEMSignedPreKey generateSignedPreKey() { - return KeysHelper.signedKEMPreKey(currentKeyId++, IDENTITY_KEY_PAIR); + return generateSignedPreKey(currentKeyId++); + } + + @Override + protected KEMSignedPreKey generateSignedPreKey(long keyId) { + return KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java index 263317e68..b6fd43f3d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java @@ -8,11 +8,11 @@ package org.whispersystems.textsecuregcm.storage; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; -import java.util.Map; import java.util.Optional; import java.util.UUID; import org.junit.jupiter.api.Test; import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest; @@ -22,6 +22,8 @@ abstract class RepeatedUseSignedPreKeyStoreTest> { protected abstract K generateSignedPreKey(); + protected abstract K generateSignedPreKey(long keyId); + protected abstract DynamoDbClient getDynamoDbClient(); @Test @@ -72,4 +74,16 @@ abstract class RepeatedUseSignedPreKeyStoreTest> { assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join()); assertEquals(Optional.of(retainedPreKey), keys.find(identifier, deviceId2).join()); } + + @Test + void findThrowsOnOutOfRangeKeyId() { + final RepeatedUseSignedPreKeyStore keys = getKeyStore(); + + final UUID identifier = UUID.randomUUID(); + final byte deviceId = 1; + final K outOfRangeKey = generateSignedPreKey(KeyIdUtil.MAX_KEY_ID + 1); + + keys.store(identifier, deviceId, outOfRangeKey).join(); + CompletableFutureTestUtil.assertFailsWithCause(IllegalStateException.class, keys.find(identifier, deviceId)); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStoreTest.java index 1eb7db977..8bdbafd1b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStoreTest.java @@ -155,6 +155,28 @@ class SingleUseECPreKeyStoreTest { assertEquals(0, preKeyStore.getCount(accountIdentifier, (byte) (deviceId + 1)).join()); } + @Test + void takeSkipsOutOfRangeKeys() { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = 1; + + final long outOfRange1 = KeyIdUtil.MAX_KEY_ID + 1; + final long outOfRange2 = KeyIdUtil.MAX_KEY_ID + 2; + final long validKeyId = 1; + + final List preKeys = List.of( + generatePreKey(outOfRange1), + generatePreKey(outOfRange2), + generatePreKey(validKeyId)); + + preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); + + final Optional taken = preKeyStore.take(accountIdentifier, deviceId).join(); + assertEquals(Optional.of(preKeys.get(2)), taken); + assertEquals(Optional.empty(), preKeyStore.take(accountIdentifier, deviceId).join()); + assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); + } + private List generateRandomPreKeys() { final Set keyIds = new HashSet<>(KEY_COUNT);