Add a crawler for orphaned prekey pages

This commit is contained in:
Ravi Khadiwala
2025-06-02 11:52:05 -05:00
committed by ravi-signal
parent 2bb14892af
commit aaa36fd8f5
7 changed files with 470 additions and 28 deletions

View File

@@ -0,0 +1,24 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import java.time.Instant;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
/**
 * The prekey pages stored for a particular device.
 * <p>
 * Instances hold an immutable snapshot of the page map supplied at construction time, so later changes to the
 * caller's map are not reflected here and the map returned by {@link #pageIdToLastModified()} cannot be modified.
 *
 * @param identifier The account identifier or phone number identifier that the keys belong to
 * @param deviceId The device identifier
 * @param currentPage If present, the active stored page prekeys are being distributed from
 * @param pageIdToLastModified The last modified time for all the device's stored pages, keyed by the pageId
 */
public record DeviceKEMPreKeyPages(
    UUID identifier, byte deviceId,
    Optional<UUID> currentPage,
    Map<UUID, Instant> pageIdToLastModified) {

  /**
   * Defensively copies the supplied page map so the record is a true value type.
   *
   * @throws NullPointerException if {@code pageIdToLastModified} is null or contains a null key or value
   */
  public DeviceKEMPreKeyPages {
    // Map.copyOf returns an unmodifiable map (and is a no-op for maps that are already unmodifiable copies)
    pageIdToLastModified = Map.copyOf(pageIdToLastModified);
  }
}

View File

@@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.storage;
import java.time.Instant;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
@@ -12,6 +13,7 @@ import java.util.concurrent.CompletableFuture;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
public class KeysManager {
@@ -131,8 +133,32 @@ public class KeysManager {
/**
 * Deletes the single-use prekeys held for one device by calling {@code delete} on both the {@code ecPreKeys} and
 * {@code pqPreKeys} stores.
 *
 * @param accountUuid the identifier passed to each store's {@code delete} call
 * @param deviceId the device whose single-use prekeys are deleted
 * @return a future combining the stores' delete operations via {@link CompletableFuture#allOf}
 */
public CompletableFuture<Void> deleteSingleUsePreKeys(final UUID accountUuid, final byte deviceId) {
return CompletableFuture.allOf(
// NOTE(review): the argument pair below appears twice in this view; this looks like the before/after lines of a
// whitespace-only diff hunk rather than four real arguments - confirm against the actual file.
ecPreKeys.delete(accountUuid, deviceId),
pqPreKeys.delete(accountUuid, deviceId)
ecPreKeys.delete(accountUuid, deviceId),
pqPreKeys.delete(accountUuid, deviceId)
);
}
/**
 * Enumerate the prekey pages currently held in remote storage for every device. Any page found here that is no
 * longer in use may be cleaned up with {@link #pruneDeadPage}.
 *
 * @param lookupConcurrency how many lookup operations may be in flight at once while the listing is populated
 * @return a publisher of the stored prekey pages
 */
public Flux<DeviceKEMPreKeyPages> listStoredKEMPreKeyPages(int lookupConcurrency) {
  // Pure delegation: the paged KEM prekey store owns the listing logic
  final Flux<DeviceKEMPreKeyPages> storedPages = pagedPqPreKeys.listStoredPages(lookupConcurrency);
  return storedPages;
}
/**
 * Delete a prekey page that is no longer needed. Callers should only remove a page when it is not the device's
 * active page and there is no possibility of it becoming the active page.
 *
 * @param identifier the owner of the page being removed
 * @param deviceId the device the page was stored for
 * @param pageId the id of the dead page to delete from storage
 * @return a future that completes once the page has been removed
 */
public CompletableFuture<Void> pruneDeadPage(final UUID identifier, final byte deviceId, final UUID pageId) {
  // Pure delegation: the paged KEM prekey store performs the actual removal
  final CompletableFuture<Void> removal = pagedPqPreKeys.deleteBundleFromS3(identifier, deviceId, pageId);
  return removal;
}
}

View File

@@ -12,6 +12,7 @@ import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.time.Instant;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
@@ -19,6 +20,9 @@ import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -40,9 +44,7 @@ import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.*;
/**
* @implNote This version of a {@link SingleUsePreKeyStore} store bundles prekeys into "pages", which are stored in on
@@ -294,6 +296,40 @@ public class PagedSingleUseKEMPreKeyStore {
.thenRun(() -> sample.stop(deleteForDeviceTimer));
}
/**
 * Lists the page objects in the bucket, groups consecutive objects that belong to the same device, and pairs each
 * group with the page id currently recorded for that device in the DynamoDB table.
 * <p>
 * Grouping uses {@code bufferUntilChanged}, so only objects that are adjacent in the listing are grouped together.
 * NOTE(review): this presumably relies on the listing returning keys in sorted order so that a device's pages are
 * contiguous - confirm.
 *
 * @param lookupConcurrency the maximum number of DynamoDB lookups to run concurrently
 * @return one {@link DeviceKEMPreKeyPages} per group of adjacent same-device objects, emitted in listing order
 */
public Flux<DeviceKEMPreKeyPages> listStoredPages(int lookupConcurrency) {
  final ListObjectsV2Request listRequest = ListObjectsV2Request.builder()
      .bucket(bucketName)
      .build();

  return Flux
      .from(s3AsyncClient.listObjectsV2Paginator(listRequest))
      .flatMapIterable(ListObjectsV2Response::contents)
      .map(PagedSingleUseKEMPreKeyStore::parseS3Key)
      .bufferUntilChanged(Function.identity(), S3PageKey::fromSameDevice)
      .flatMapSequential(devicePages -> {
        // Every entry in the buffer shares an owner, so the first entry identifies the device
        final S3PageKey firstPage = devicePages.getFirst();
        final UUID identifier = firstPage.identifier();
        final byte deviceId = firstPage.deviceId();

        return Mono.fromCompletionStage(() -> {
          final GetItemRequest currentPageLookup = GetItemRequest.builder()
              .tableName(tableName)
              .key(Map.of(
                  KEY_ACCOUNT_UUID, AttributeValues.fromUUID(identifier),
                  KEY_DEVICE_ID, AttributeValues.fromInt(deviceId)))
              // Make sure we get the most up to date pageId to minimize cases where we see a new page in S3 but
              // view a stale dynamodb record
              .consistentRead(true)
              .projectionExpression("#uuid,#deviceid,#pageid")
              .expressionAttributeNames(Map.of(
                  "#uuid", KEY_ACCOUNT_UUID,
                  "#deviceid", KEY_DEVICE_ID,
                  "#pageid", ATTR_PAGE_ID))
              .build();

          return dynamoDbAsyncClient.getItem(currentPageLookup)
              .thenApply(response -> {
                // Empty when the item carries no page id attribute
                final Optional<UUID> currentPage =
                    Optional.ofNullable(AttributeValues.getUUID(response.item(), ATTR_PAGE_ID, null));
                final Map<UUID, Instant> lastModifiedByPageId = devicePages.stream()
                    .collect(Collectors.toMap(S3PageKey::pageId, S3PageKey::lastModified));
                return new DeviceKEMPreKeyPages(identifier, deviceId, currentPage, lastModifiedByPageId);
              });
        });
      }, lookupConcurrency);
}
private CompletableFuture<Void> deleteItems(final UUID identifier,
final Flux<Map<String, AttributeValue>> items) {
return items
@@ -322,6 +358,29 @@ public class PagedSingleUseKEMPreKeyStore {
return String.format("%s/%s/%s", identifier, deviceId, pageId);
}
/**
 * The parsed components of a stored page's S3 object key, together with the object's last-modified time.
 *
 * @param identifier the first path component of the key
 * @param deviceId the second path component of the key
 * @param pageId the third path component of the key
 * @param lastModified the last-modified time reported for the object
 */
private record S3PageKey(UUID identifier, byte deviceId, UUID pageId, Instant lastModified) {

  /**
   * @param other the key to compare against
   * @return true when both keys carry the same device id and the same identifier
   */
  boolean fromSameDevice(final S3PageKey other) {
    if (deviceId != other.deviceId()) {
      return false;
    }
    return identifier.equals(other.identifier());
  }
}
/**
 * Parses an S3 object whose key has exactly three {@code /}-separated components (identifier UUID, device id,
 * page UUID) into an {@link S3PageKey}.
 *
 * @param page the listed S3 object
 * @return the parsed key components along with the object's last-modified time
 * @throws IllegalArgumentException if the key does not have exactly three components or a component cannot be
 *                                  parsed; the original failure is attached as the cause
 */
private static S3PageKey parseS3Key(final S3Object page) {
  try {
    // A negative limit keeps trailing empty components, so any key without exactly two separators is rejected
    final String[] components = page.key().split("/", -1);
    if (components.length != 3) {
      throw new IllegalArgumentException("wrong number of path components");
    }
    final UUID identifier = UUID.fromString(components[0]);
    final byte deviceId = Byte.parseByte(components[1]);
    final UUID pageId = UUID.fromString(components[2]);
    return new S3PageKey(identifier, deviceId, pageId, page.lastModified());
  } catch (IllegalArgumentException e) {
    // Also covers NumberFormatException from the device id; rewrap so the message names the offending key
    throw new IllegalArgumentException("invalid s3 page key: " + page.key(), e);
  }
}
private CompletableFuture<UUID> writeBundleToS3(final UUID identifier, final byte deviceId,
final ByteBuffer bundle) {
final UUID pageId = UUID.randomUUID();
@@ -332,7 +391,7 @@ public class PagedSingleUseKEMPreKeyStore {
.thenApply(ignoredResponse -> pageId);
}
private CompletableFuture<Void> deleteBundleFromS3(final UUID identifier, final byte deviceId, final UUID pageId) {
CompletableFuture<Void> deleteBundleFromS3(final UUID identifier, final byte deviceId, final UUID pageId) {
return s3AsyncClient.deleteObject(DeleteObjectRequest.builder()
.bucket(bucketName)
.key(s3Key(identifier, deviceId, pageId))

View File

@@ -213,12 +213,14 @@ record CommandDependencies(
.credentialsProvider(awsCredentialsProvider)
.region(Region.of(configuration.getPagedSingleUseKEMPreKeyStore().region()))
.build();
PagedSingleUseKEMPreKeyStore pagedSingleUseKEMPreKeyStore = new PagedSingleUseKEMPreKeyStore(
dynamoDbAsyncClient, asyncKeysS3Client,
configuration.getDynamoDbTables().getPagedKemKeys().getTableName(),
configuration.getPagedSingleUseKEMPreKeyStore().bucket());
KeysManager keys = new KeysManager(
new SingleUseECPreKeyStore(dynamoDbAsyncClient, configuration.getDynamoDbTables().getEcKeys().getTableName()),
new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, configuration.getDynamoDbTables().getKemKeys().getTableName()),
new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient, asyncKeysS3Client,
configuration.getDynamoDbTables().getPagedKemKeys().getTableName(),
configuration.getPagedSingleUseKEMPreKeyStore().bucket()),
pagedSingleUseKEMPreKeyStore,
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
configuration.getDynamoDbTables().getEcSignedPreKeys().getTableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,

View File

@@ -0,0 +1,143 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.workers;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.core.Application;
import io.dropwizard.core.setup.Environment;
import io.micrometer.core.instrument.Metrics;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Stream;
import net.sourceforge.argparse4j.inf.Namespace;
import net.sourceforge.argparse4j.inf.Subparser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.DeviceKEMPreKeyPages;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.retry.Retry;
/**
 * A command that crawls every stored KEM prekey page and removes the pages that are orphaned: pages that are not
 * the current page of their device and that were last modified before a configurable cutoff.
 * <p>
 * By default the command runs in dry-run mode and only reports what it would delete.
 */
public class RemoveOrphanedPreKeyPagesCommand extends AbstractCommandWithDependencies {

  private final Logger logger = LoggerFactory.getLogger(getClass());

  private static final String PAGE_CONSIDERED_COUNTER_NAME = MetricsUtil.name(RemoveOrphanedPreKeyPagesCommand.class,
      "pageConsidered");

  @VisibleForTesting
  static final String DRY_RUN_ARGUMENT = "dry-run";

  @VisibleForTesting
  static final String CONCURRENCY_ARGUMENT = "concurrency";
  private static final int DEFAULT_CONCURRENCY = 10;

  @VisibleForTesting
  static final String MINIMUM_ORPHAN_AGE_ARGUMENT = "orphan-age";
  private static final Duration DEFAULT_MINIMUM_ORPHAN_AGE = Duration.ofDays(7);

  private final Clock clock;

  /**
   * @param clock the clock used to compute the cutoff before which unreferenced pages are considered orphans
   */
  public RemoveOrphanedPreKeyPagesCommand(final Clock clock) {
    super(new Application<>() {
      @Override
      public void run(final WhisperServerConfiguration configuration, final Environment environment) {
      }
    }, "remove-orphaned-pre-key-pages", "Remove pre-key pages that are unreferenced");
    this.clock = clock;
  }

  /**
   * Registers the {@code --concurrency}, {@code --dry-run} and {@code --minimum-orphan-age} arguments.
   *
   * @param subparser the parser to add this command's arguments to
   */
  @Override
  public void configure(final Subparser subparser) {
    super.configure(subparser);

    subparser.addArgument("--concurrency")
        .type(Integer.class)
        .dest(CONCURRENCY_ARGUMENT)
        .required(false)
        .setDefault(DEFAULT_CONCURRENCY)
        .help("The maximum number of parallel dynamodb operations to process concurrently");

    subparser.addArgument("--dry-run")
        .type(Boolean.class)
        .dest(DRY_RUN_ARGUMENT)
        .required(false)
        .setDefault(true)
        .help("If true, don't actually remove orphaned pre-key pages");

    subparser.addArgument("--minimum-orphan-age")
        .type(String.class)
        .dest(MINIMUM_ORPHAN_AGE_ARGUMENT)
        .required(false)
        .setDefault(DEFAULT_MINIMUM_ORPHAN_AGE.toString())
        .help("Only remove orphans that are at least this old. Provide as an ISO-8601 duration string");
  }

  /**
   * Crawls the stored prekey pages, deletes (or, in dry-run mode, merely counts) every page selected by
   * {@link #shouldDeletePage}, and logs the total.
   *
   * @throws IllegalArgumentException if the configured minimum orphan age is negative
   */
  @Override
  protected void run(final Environment environment, final Namespace namespace,
      final WhisperServerConfiguration configuration, final CommandDependencies commandDependencies) throws Exception {

    final int concurrency = Objects.requireNonNull(namespace.getInt(CONCURRENCY_ARGUMENT));
    final boolean dryRun = Objects.requireNonNull(namespace.getBoolean(DRY_RUN_ARGUMENT));
    final Duration orphanAgeMinimum =
        Duration.parse(Objects.requireNonNull(namespace.getString(MINIMUM_ORPHAN_AGE_ARGUMENT)));

    // A negative age would place the cutoff in the future, making every page that is not a device's current page
    // eligible for deletion no matter how recently it was written
    if (orphanAgeMinimum.isNegative()) {
      throw new IllegalArgumentException("minimum orphan age must not be negative: " + orphanAgeMinimum);
    }

    final Instant olderThan = clock.instant().minus(orphanAgeMinimum);

    logger.info("Crawling preKey page store with concurrency={}, processors={}, dryRun={}. Removing orphans written before={}",
        concurrency,
        Runtime.getRuntime().availableProcessors(),
        dryRun,
        olderThan);

    final KeysManager keysManager = commandDependencies.keysManager();

    // Each orphaned page contributes 1 to the total, whether it was really deleted or only identified by a dry run
    final int orphanedPages = keysManager.listStoredKEMPreKeyPages(concurrency)
        .flatMap(storedPages -> Flux.fromStream(getDeletablePages(storedPages, olderThan))
            // Pages of a single device are removed one at a time; devices are processed concurrently
            .concatMap(pageId -> dryRun
                ? Mono.just(1)
                : Mono.fromCompletionStage(() ->
                        keysManager.pruneDeadPage(storedPages.identifier(), storedPages.deviceId(), pageId))
                    .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)))
                    .thenReturn(1)), concurrency)
        .reduce(0, Integer::sum)
        .block();

    if (dryRun) {
      logger.info("Dry run: would have deleted {} orphaned pages", orphanedPages);
    } else {
      logger.info("Deleted {} orphaned pages", orphanedPages);
    }
  }

  /**
   * Selects the ids of a device's stored pages that are eligible for deletion.
   *
   * @param storedPages the device's stored pages and its current page, if any
   * @param olderThan only pages last modified before this instant are eligible
   * @return the ids of the pages for which {@link #shouldDeletePage} returns true
   */
  private static Stream<UUID> getDeletablePages(final DeviceKEMPreKeyPages storedPages, final Instant olderThan) {
    return storedPages.pageIdToLastModified()
        .entrySet()
        .stream()
        .filter(page -> {
          final UUID pageId = page.getKey();
          final Instant lastModified = page.getValue();
          return shouldDeletePage(storedPages.currentPage(), pageId, olderThan, lastModified);
        })
        .map(Map.Entry::getKey);
  }

  /**
   * Decides whether a stored page should be deleted, and increments the page-considered counter tagged with the
   * outcome of both checks.
   *
   * @param currentPage the device's current page, if it has one
   * @param page the id of the page under consideration
   * @param deleteBefore the cutoff; only pages last modified strictly before it are stale
   * @param lastModified when the page under consideration was last modified
   * @return true if the page is not the device's current page and is stale
   */
  @VisibleForTesting
  static boolean shouldDeletePage(
      final Optional<UUID> currentPage, final UUID page,
      final Instant deleteBefore, final Instant lastModified) {
    final boolean isCurrentPageForDevice = currentPage.map(uuid -> uuid.equals(page)).orElse(false);
    final boolean isStale = lastModified.isBefore(deleteBefore);

    Metrics.counter(PAGE_CONSIDERED_COUNTER_NAME,
            "isCurrentPageForDevice", Boolean.toString(isCurrentPageForDevice),
            "stale", Boolean.toString(isStale))
        .increment();

    return !isCurrentPageForDevice && isStale;
  }
}