Always read from new and old PQ prekey stores, add experiment to start writing to new prekey store

This commit is contained in:
ravi-signal
2025-07-09 09:17:17 -05:00
committed by GitHub
parent 80c11e7eda
commit c9f21d5970
10 changed files with 184 additions and 36 deletions

View File

@@ -368,6 +368,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
MetricsUtil.configureRegistries(config, environment, dynamicConfigurationManager);
ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(dynamicConfigurationManager);
if (config.getServerFactory() instanceof DefaultServerFactory defaultServerFactory) {
defaultServerFactory.getApplicationConnectors()
.forEach(connectorFactory -> {
@@ -444,7 +446,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDynamoDbTables().getPagedKemKeys().getTableName(),
config.getPagedSingleUseKEMPreKeyStore().bucket()),
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, config.getDynamoDbTables().getEcSignedPreKeys().getTableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, config.getDynamoDbTables().getKemLastResortKeys().getTableName()));
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, config.getDynamoDbTables().getKemLastResortKeys().getTableName()),
experimentEnrollmentManager);
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getMessages().getTableName(),
config.getDynamoDbTables().getMessages().getExpiration(),
@@ -604,8 +607,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ExternalServiceCredentialsGenerator svr2CredentialsGenerator = SecureValueRecovery2Controller.credentialsGenerator(
config.getSvr2Configuration());
ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(
dynamicConfigurationManager);
RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager =
new RegistrationRecoveryPasswordsManager(registrationRecoveryPasswords);
UsernameHashZkProofVerifier usernameHashZkProofVerifier = new UsernameHashZkProofVerifier();

View File

@@ -5,7 +5,7 @@
package org.whispersystems.textsecuregcm.storage;
import java.time.Instant;
import io.micrometer.core.instrument.Metrics;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
@@ -13,6 +13,8 @@ import java.util.concurrent.CompletableFuture;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
@@ -23,18 +25,25 @@ public class KeysManager {
private final PagedSingleUseKEMPreKeyStore pagedPqPreKeys;
private final RepeatedUseECSignedPreKeyStore ecSignedPreKeys;
private final RepeatedUseKEMSignedPreKeyStore pqLastResortKeys;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
public static String PAGED_KEYS_EXPERIMENT_NAME = "pagedPreKeys";
private static final String TAKE_PQ_NAME = MetricsUtil.name(KeysManager.class, "takePq");
public KeysManager(
final SingleUseECPreKeyStore ecPreKeys,
final SingleUseKEMPreKeyStore pqPreKeys,
final PagedSingleUseKEMPreKeyStore pagedPqPreKeys,
final RepeatedUseECSignedPreKeyStore ecSignedPreKeys,
final RepeatedUseKEMSignedPreKeyStore pqLastResortKeys) {
final RepeatedUseKEMSignedPreKeyStore pqLastResortKeys,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this.ecPreKeys = ecPreKeys;
this.pqPreKeys = pqPreKeys;
this.pagedPqPreKeys = pagedPqPreKeys;
this.ecSignedPreKeys = ecSignedPreKeys;
this.pqLastResortKeys = pqLastResortKeys;
this.experimentEnrollmentManager = experimentEnrollmentManager;
}
public TransactWriteItem buildWriteItemForEcSignedPreKey(final UUID identifier,
@@ -79,22 +88,31 @@ public class KeysManager {
);
}
public CompletableFuture<Void> storeEcSignedPreKeys(final UUID identifier, final byte deviceId, final ECSignedPreKey ecSignedPreKey) {
public CompletableFuture<Void> storeEcSignedPreKeys(final UUID identifier, final byte deviceId,
final ECSignedPreKey ecSignedPreKey) {
return ecSignedPreKeys.store(identifier, deviceId, ecSignedPreKey);
}
public CompletableFuture<Void> storePqLastResort(final UUID identifier, final byte deviceId, final KEMSignedPreKey lastResortKey) {
public CompletableFuture<Void> storePqLastResort(final UUID identifier, final byte deviceId,
final KEMSignedPreKey lastResortKey) {
return pqLastResortKeys.store(identifier, deviceId, lastResortKey);
}
public CompletableFuture<Void> storeEcOneTimePreKeys(final UUID identifier, final byte deviceId,
final List<ECPreKey> preKeys) {
final List<ECPreKey> preKeys) {
return ecPreKeys.store(identifier, deviceId, preKeys);
}
public CompletableFuture<Void> storeKemOneTimePreKeys(final UUID identifier, final byte deviceId,
final List<KEMSignedPreKey> preKeys) {
return pqPreKeys.store(identifier, deviceId, preKeys);
final List<KEMSignedPreKey> preKeys) {
final boolean enrolledInPagedKeys = experimentEnrollmentManager.isEnrolled(identifier, PAGED_KEYS_EXPERIMENT_NAME);
final CompletableFuture<Void> deleteOtherKeys = enrolledInPagedKeys
? pqPreKeys.delete(identifier, deviceId)
: pagedPqPreKeys.delete(identifier, deviceId);
return deleteOtherKeys.thenCompose(ignored -> enrolledInPagedKeys
? pagedPqPreKeys.store(identifier, deviceId, preKeys)
: pqPreKeys.store(identifier, deviceId, preKeys));
}
public CompletableFuture<Optional<ECPreKey>> takeEC(final UUID identifier, final byte deviceId) {
@@ -102,10 +120,36 @@ public class KeysManager {
}
public CompletableFuture<Optional<KEMSignedPreKey>> takePQ(final UUID identifier, final byte deviceId) {
return pqPreKeys.take(identifier, deviceId)
final boolean enrolledInPagedKeys = experimentEnrollmentManager.isEnrolled(identifier, PAGED_KEYS_EXPERIMENT_NAME);
return tagTakePQ(pagedPqPreKeys.take(identifier, deviceId), PQSource.PAGE, enrolledInPagedKeys)
.thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey
.map(ignored -> CompletableFuture.completedFuture(maybeSingleUsePreKey))
.orElseGet(() -> tagTakePQ(pqPreKeys.take(identifier, deviceId), PQSource.ROW, enrolledInPagedKeys)))
.thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey
.map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey))
.orElseGet(() -> pqLastResortKeys.find(identifier, deviceId)));
.orElseGet(() -> tagTakePQ(pqLastResortKeys.find(identifier, deviceId), PQSource.LAST_RESORT, enrolledInPagedKeys)));
}
private enum PQSource {
PAGE,
ROW,
LAST_RESORT
}
private CompletableFuture<Optional<KEMSignedPreKey>> tagTakePQ(CompletableFuture<Optional<KEMSignedPreKey>> prekey, final PQSource source, final boolean enrolledInPagedKeys) {
return prekey.thenApply(maybeSingleUsePreKey -> {
final Optional<String> maybeSourceTag = maybeSingleUsePreKey
// If we found a PK, use this source tag
.map(ignore -> source.name())
// If we didn't and this is our last resort, we didn't find a PK
.or(() -> source == PQSource.LAST_RESORT ? Optional.of("absent") : Optional.empty());
maybeSourceTag.ifPresent(sourceTag -> {
Metrics.counter(TAKE_PQ_NAME,
"source", sourceTag,
"enrolled", Boolean.toString(enrolledInPagedKeys))
.increment();
});
return maybeSingleUsePreKey;
});
}
public CompletableFuture<Optional<KEMSignedPreKey>> getLastResort(final UUID identifier, final byte deviceId) {
@@ -121,20 +165,24 @@ public class KeysManager {
}
public CompletableFuture<Integer> getPqCount(final UUID identifier, final byte deviceId) {
return pqPreKeys.getCount(identifier, deviceId);
return pagedPqPreKeys.getCount(identifier, deviceId).thenCompose(count -> count == 0
? pqPreKeys.getCount(identifier, deviceId)
: CompletableFuture.completedFuture(count));
}
public CompletableFuture<Void> deleteSingleUsePreKeys(final UUID identifier) {
return CompletableFuture.allOf(
ecPreKeys.delete(identifier),
pqPreKeys.delete(identifier)
pqPreKeys.delete(identifier),
pagedPqPreKeys.delete(identifier)
);
}
public CompletableFuture<Void> deleteSingleUsePreKeys(final UUID accountUuid, final byte deviceId) {
return CompletableFuture.allOf(
ecPreKeys.delete(accountUuid, deviceId),
pqPreKeys.delete(accountUuid, deviceId)
pqPreKeys.delete(accountUuid, deviceId),
pagedPqPreKeys.delete(accountUuid, deviceId)
);
}

View File

@@ -22,7 +22,6 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -44,7 +43,12 @@ import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.*;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Response;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.S3Object;
/**
* @implNote This version of a {@link SingleUsePreKeyStore} store bundles prekeys into "pages", which are stored in on

View File

@@ -33,6 +33,7 @@ import org.whispersystems.textsecuregcm.backup.Cdn3RemoteStorageManager;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.SecureStorageController;
import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher;
@@ -122,6 +123,9 @@ record CommandDependencies(
new DynamicConfigurationManager<>(
configuration.getDynamicConfig().build(awsCredentialsProvider, dynamicConfigurationExecutor), DynamicConfiguration.class);
dynamicConfigurationManager.start();
ExperimentEnrollmentManager experimentEnrollmentManager =
new ExperimentEnrollmentManager(dynamicConfigurationManager);
final ClientResources.Builder redisClientResourcesBuilder = ClientResources.builder();
FaultTolerantRedisClusterClient cacheCluster = configuration.getCacheClusterConfiguration()
@@ -224,7 +228,8 @@ record CommandDependencies(
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
configuration.getDynamoDbTables().getEcSignedPreKeys().getTableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
configuration.getDynamoDbTables().getKemLastResortKeys().getTableName()));
configuration.getDynamoDbTables().getKemLastResortKeys().getTableName()),
experimentEnrollmentManager);
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration(),