Use registration ID or creation timestamp in the transfer archive flow

This commit is contained in:
Katherine
2025-07-30 15:32:49 -04:00
committed by GitHub
parent 30774bbc40
commit db4c71368c
6 changed files with 308 additions and 87 deletions

View File

@@ -115,6 +115,8 @@ public class DeviceController {
private static final String WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME =
MetricsUtil.name(DeviceController.class, "waitForTransferArchiveDuration");
private static final String RECORD_TRANSFER_ARCHIVE_UPLOADED_COUNTER_NAME = MetricsUtil.name(DeviceController.class, "recordTransferArchiveUploaded");
private static final String HAS_REGISTRATION_ID_TAG_NAME = "hasRegistrationId";
@VisibleForTesting
static final int MIN_TOKEN_IDENTIFIER_LENGTH = 32;
@@ -533,8 +535,14 @@ public class DeviceController {
@ApiResponse(responseCode = "422", description = "The request object could not be parsed or was otherwise invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
public CompletionStage<Void> recordTransferArchiveUploaded(@Auth final AuthenticatedDevice authenticatedDevice,
@NotNull @Valid final TransferArchiveUploadedRequest transferArchiveUploadedRequest) {
@NotNull @Valid final TransferArchiveUploadedRequest transferArchiveUploadedRequest,
@HeaderParam(HttpHeaders.USER_AGENT) @Nullable String userAgent) {
Metrics.counter(RECORD_TRANSFER_ARCHIVE_UPLOADED_COUNTER_NAME, Tags.of(
UserAgentTagUtil.getPlatformTag(userAgent),
io.micrometer.core.instrument.Tag.of(
HAS_REGISTRATION_ID_TAG_NAME,
String.valueOf(transferArchiveUploadedRequest.registrationId().isPresent()))
)).increment();
return rateLimiters.getUploadTransferArchiveLimiter()
.validateAsync(authenticatedDevice.accountIdentifier())
.thenCompose(ignored -> accounts.getByAccountIdentifierAsync(authenticatedDevice.accountIdentifier()))
@@ -544,7 +552,8 @@ public class DeviceController {
return accounts.recordTransferArchiveUpload(account,
transferArchiveUploadedRequest.destinationDeviceId(),
Instant.ofEpochMilli(transferArchiveUploadedRequest.destinationDeviceCreated()),
transferArchiveUploadedRequest.destinationDeviceCreated().map(Instant::ofEpochMilli),
transferArchiveUploadedRequest.registrationId(),
transferArchiveUploadedRequest.transferArchive());
});
}

View File

@@ -7,21 +7,30 @@ package org.whispersystems.textsecuregcm.entities;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.Valid;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Positive;
import org.whispersystems.textsecuregcm.storage.Device;
import java.util.Optional;
public record TransferArchiveUploadedRequest(
@Min(1)
@Max(Device.MAXIMUM_DEVICE_ID)
@Schema(description = "The ID of the device for which the transfer archive has been prepared")
byte destinationDeviceId,
@Positive
@Schema(description = "The timestamp, in milliseconds since the epoch, at which the destination device was created")
long destinationDeviceCreated,
@Schema(description = """
The timestamp, in milliseconds since the epoch, at which the destination device was created.
Deprecated in favor of registrationId.
""", deprecated = true)
@Deprecated
Optional<@Positive Long> destinationDeviceCreated,
@Schema(description = "The registration ID of the destination device")
Optional<@Min(0) @Max(Device.MAX_REGISTRATION_ID) Integer> registrationId,
@NotNull
@Valid
@@ -29,4 +38,10 @@ public record TransferArchiveUploadedRequest(
The location of the transfer archive if the archive was successfully uploaded, otherwise a error indicating that
the upload has failed and the destination device should stop waiting
""", oneOf = {RemoteAttachment.class, RemoteAttachmentError.class})
TransferArchiveResult transferArchive) {}
TransferArchiveResult transferArchive) {
@AssertTrue
@Schema(hidden = true)
public boolean isExactlyOneDisambiguatorProvided() {
return destinationDeviceCreated.isPresent() ^ registrationId.isPresent();
}
}

View File

@@ -50,6 +50,7 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
@@ -85,6 +86,7 @@ import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecoveryClient;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.RegistrationIdValidator;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux;
@@ -111,6 +113,8 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
private static final String DELETE_COUNTER_NAME = name(AccountsManager.class, "deleteCounter");
private static final String COUNTRY_CODE_TAG_NAME = "country";
private static final String DELETION_REASON_TAG_NAME = "reason";
private static final String TIMESTAMP_BASED_TRANSFER_ARCHIVE_KEY_COUNTER_NAME = name(AccountsManager.class, "timestampRedisKeyCounter");
private static final String REGISTRATION_ID_BASED_TRANSFER_ARCHIVE_KEY_COUNTER_NAME = name(AccountsManager.class,"registrationIdRedisKeyCounter");
private static final Logger logger = LoggerFactory.getLogger(AccountsManager.class);
@@ -140,7 +144,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
private final Map<String, CompletableFuture<Optional<DeviceInfo>>> waitForDeviceFuturesByTokenIdentifier =
new ConcurrentHashMap<>();
private final Map<TimestampedDeviceIdentifier, CompletableFuture<Optional<TransferArchiveResult>>> waitForTransferArchiveFuturesByDeviceIdentifier =
private final Map<DeviceIdentifier, CompletableFuture<Optional<TransferArchiveResult>>> waitForTransferArchiveFuturesByDeviceIdentifier =
new ConcurrentHashMap<>();
private final Map<String, CompletableFuture<Optional<RestoreAccountRequest>>> waitForRestoreAccountRequestFuturesByToken =
@@ -155,6 +159,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
private static final Duration RECENTLY_ADDED_TRANSFER_ARCHIVE_TTL = Duration.ofHours(1);
private static final String TRANSFER_ARCHIVE_PREFIX = "transfer_archive::";
private static final String TRANSFER_ARCHIVE_KEYSPACE_PATTERN = "__keyspace@0__:" + TRANSFER_ARCHIVE_PREFIX + "*";
private static final String TRANSFER_ARCHIVE_REGISTRATION_ID_PATTERN = "registrationId";
private static final Duration RESTORE_ACCOUNT_REQUEST_TTL = Duration.ofHours(1);
private static final String RESTORE_ACCOUNT_REQUEST_PREFIX = "restore_account::";
@@ -194,7 +199,14 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
}
}
private record TimestampedDeviceIdentifier(UUID accountIdentifier, byte deviceId, Instant deviceCreationTimestamp) {
private interface DeviceIdentifier {}
private record TimestampDeviceIdentifier(UUID accountIdentifier, byte deviceId, Instant deviceCreationTimestamp)
implements DeviceIdentifier {
}
private record RegistrationIdDeviceIdentifier(UUID accountIdentifier, byte deviceId,
int registrationId) implements DeviceIdentifier {
}
public AccountsManager(final Accounts accounts,
@@ -1509,34 +1521,66 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
}
public CompletableFuture<Optional<TransferArchiveResult>> waitForTransferArchive(final Account account, final Device device, final Duration timeout) {
final TimestampedDeviceIdentifier deviceIdentifier =
new TimestampedDeviceIdentifier(account.getIdentifier(IdentityType.ACI),
device.getId(),
Instant.ofEpochMilli(device.getCreated()));
final DeviceIdentifier timestampDeviceIdentifier = new TimestampDeviceIdentifier(account.getIdentifier(IdentityType.ACI), device.getId(), Instant.ofEpochMilli(device.getCreated()));
final String timestampTransferArchiveKey = getTimestampTransferArchiveKey(account.getIdentifier(IdentityType.ACI), device.getId(), Instant.ofEpochMilli(device.getCreated()));
return waitForPubSubKey(waitForTransferArchiveFuturesByDeviceIdentifier,
deviceIdentifier,
getTransferArchiveKey(account.getIdentifier(IdentityType.ACI), device.getId(), Instant.ofEpochMilli(device.getCreated())),
final DeviceIdentifier registrationIdDeviceIdentifier = new RegistrationIdDeviceIdentifier(account.getIdentifier(IdentityType.ACI), device.getId(), device.getRegistrationId(IdentityType.ACI));
final String registrationIdTransferArchiveKey = getRegistrationIdTransferArchiveKey(account.getIdentifier(IdentityType.ACI), device.getId(), device.getRegistrationId(IdentityType.ACI));
final CompletableFuture<Optional<TransferArchiveResult>> timestampFuture = waitForPubSubKey(waitForTransferArchiveFuturesByDeviceIdentifier,
timestampDeviceIdentifier,
timestampTransferArchiveKey,
timeout,
this::handleTransferArchiveAdded);
final CompletableFuture<Optional<TransferArchiveResult>> registrationIdFuture = waitForPubSubKey(waitForTransferArchiveFuturesByDeviceIdentifier,
registrationIdDeviceIdentifier,
registrationIdTransferArchiveKey,
timeout,
this::handleTransferArchiveAdded);
return firstSuccessfulTransferArchiveFuture(List.of(timestampFuture, registrationIdFuture));
}
@VisibleForTesting
static CompletableFuture<Optional<TransferArchiveResult>> firstSuccessfulTransferArchiveFuture(
final List<CompletableFuture<Optional<TransferArchiveResult>>> futures) {
final CompletableFuture<Optional<TransferArchiveResult>> result = new CompletableFuture<>();
final AtomicInteger remaining = new AtomicInteger(futures.size());
for (CompletableFuture<Optional<TransferArchiveResult>> future : futures) {
future.whenComplete((value, _) -> {
if (value.isPresent()) {
result.complete(value);
} else if (remaining.decrementAndGet() == 0) {
result.complete(Optional.empty());
}
});
}
return result;
}
public CompletableFuture<Void> recordTransferArchiveUpload(final Account account,
final byte destinationDeviceId,
final Instant destinationDeviceCreationTimestamp,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Instant> destinationDeviceCreationTimestamp,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Integer> maybeRegistrationId,
final TransferArchiveResult transferArchiveResult) {
final String key = getTransferArchiveKey(account.getIdentifier(IdentityType.ACI),
destinationDeviceId,
destinationDeviceCreationTimestamp);
try {
final String transferArchiveJson = SystemMapper.jsonMapper().writeValueAsString(transferArchiveResult);
return pubSubRedisClient.withConnection(connection ->
connection.async().set(key, transferArchiveJson, SetArgs.Builder.ex(RECENTLY_ADDED_TRANSFER_ARCHIVE_TTL)))
.thenRun(Util.NOOP)
.toCompletableFuture();
return pubSubRedisClient.withConnection(connection -> {
final String key = destinationDeviceCreationTimestamp
.map(timestamp -> getTimestampTransferArchiveKey(account.getIdentifier(IdentityType.ACI), destinationDeviceId, timestamp))
.orElseGet(() -> maybeRegistrationId
.map(registrationId -> getRegistrationIdTransferArchiveKey(account.getIdentifier(IdentityType.ACI), destinationDeviceId, registrationId))
// We validate the request object so this should never happen
.orElseThrow(() -> new AssertionError("No creation timestamp or registration ID provided")));
return connection.async()
.set(key, transferArchiveJson, SetArgs.Builder.ex(RECENTLY_ADDED_TRANSFER_ARCHIVE_TTL))
.thenRun(Util.NOOP)
.toCompletableFuture();
});
} catch (final JsonProcessingException e) {
// This should never happen for well-defined objects we control
throw new UncheckedIOException(e);
@@ -1552,15 +1596,27 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
}
}
private static String getTransferArchiveKey(final UUID accountIdentifier,
private static String getTimestampTransferArchiveKey(final UUID accountIdentifier,
final byte destinationDeviceId,
final Instant destinationDeviceCreationTimestamp) {
Metrics.counter(TIMESTAMP_BASED_TRANSFER_ARCHIVE_KEY_COUNTER_NAME).increment();
return TRANSFER_ARCHIVE_PREFIX + accountIdentifier.toString() +
":" + destinationDeviceId +
":" + destinationDeviceCreationTimestamp.toEpochMilli();
}
private static String getRegistrationIdTransferArchiveKey(final UUID accountIdentifier,
final byte destinationDeviceId,
final int registrationId) {
Metrics.counter(REGISTRATION_ID_BASED_TRANSFER_ARCHIVE_KEY_COUNTER_NAME).increment();
return TRANSFER_ARCHIVE_PREFIX + accountIdentifier.toString() +
":" + destinationDeviceId +
":" + TRANSFER_ARCHIVE_REGISTRATION_ID_PATTERN +
":" + registrationId;
}
public CompletableFuture<Optional<RestoreAccountRequest>> waitForRestoreAccountRequest(final String token, final Duration timeout) {
return waitForPubSubKey(waitForRestoreAccountRequestFuturesByToken,
token,
@@ -1648,23 +1704,36 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
} else if (TRANSFER_ARCHIVE_KEYSPACE_PATTERN.equals(pattern) && "set".equalsIgnoreCase(message)) {
// The `- 1` here compensates for the '*' in the pattern
final String[] deviceIdentifierComponents =
channel.substring(TRANSFER_ARCHIVE_KEYSPACE_PATTERN.length() - 1).split(":", 3);
channel.substring(TRANSFER_ARCHIVE_KEYSPACE_PATTERN.length() - 1).split(":", 4);
if (deviceIdentifierComponents.length != 3) {
logger.error("Could not parse timestamped device identifier; unexpected component count");
if (deviceIdentifierComponents.length != 3 && deviceIdentifierComponents.length != 4) {
logger.error("Could not parse device identifier; unexpected component count");
return;
}
final DeviceIdentifier deviceIdentifier;
final String transferArchiveKey;
try {
final TimestampedDeviceIdentifier deviceIdentifier;
final String transferArchiveKey;
{
final UUID accountIdentifier = UUID.fromString(deviceIdentifierComponents[0]);
final byte deviceId = Byte.parseByte(deviceIdentifierComponents[1]);
final UUID accountIdentifier = UUID.fromString(deviceIdentifierComponents[0]);
final byte deviceId = Byte.parseByte(deviceIdentifierComponents[1]);
if (deviceIdentifierComponents.length == 3) {
// Parse the old transfer archive Redis key format
final Instant deviceCreationTimestamp = Instant.ofEpochMilli(Long.parseLong(deviceIdentifierComponents[2]));
deviceIdentifier = new TimestampedDeviceIdentifier(accountIdentifier, deviceId, deviceCreationTimestamp);
transferArchiveKey = getTransferArchiveKey(accountIdentifier, deviceId, deviceCreationTimestamp);
deviceIdentifier = new TimestampDeviceIdentifier(accountIdentifier, deviceId, deviceCreationTimestamp);
transferArchiveKey = getTimestampTransferArchiveKey(accountIdentifier, deviceId, deviceCreationTimestamp);
} else {
final String maybeRegistrationIdPattern = deviceIdentifierComponents[2];
if (!maybeRegistrationIdPattern.equals(TRANSFER_ARCHIVE_REGISTRATION_ID_PATTERN)) {
throw new IllegalArgumentException("Could not parse Redis key with pattern " + maybeRegistrationIdPattern);
}
final int registrationId = Integer.parseInt(deviceIdentifierComponents[3]);
if (!RegistrationIdValidator.validRegistrationId(registrationId)) {
throw new IllegalArgumentException("Invalid registration ID: " + registrationId);
}
deviceIdentifier = new RegistrationIdDeviceIdentifier(accountIdentifier, deviceId, registrationId);
transferArchiveKey = getRegistrationIdTransferArchiveKey(accountIdentifier, deviceId, registrationId);
}
Optional.ofNullable(waitForTransferArchiveFuturesByDeviceIdentifier.remove(deviceIdentifier))
@@ -1677,7 +1746,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
}
}));
} catch (final IllegalArgumentException e) {
logger.error("Could not parse timestamped device identifier", e);
logger.error("Could not parse device identifier", e);
}
} else if (RESTORE_ACCOUNT_REQUEST_KEYSPACE_PATTERN.equalsIgnoreCase(pattern) && "set".equalsIgnoreCase(message)) {
// The `- 1` here compensates for the '*' in the pattern