Add plumbing for a "wait for transfer archive" system

This commit is contained in:
Jon Chambers
2024-10-10 11:59:37 -04:00
committed by Jon Chambers
parent 0adaa331a1
commit 7ff48155d6
3 changed files with 307 additions and 55 deletions

View File

@@ -49,6 +49,7 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
@@ -69,6 +70,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.RemoteAttachment;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
@@ -115,7 +117,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
private final Accounts accounts;
private final PhoneNumberIdentifiers phoneNumberIdentifiers;
private final FaultTolerantRedisClusterClient cacheCluster;
private final FaultTolerantRedisClient pubSubRedisSingleton;
private final FaultTolerantRedisClient pubSubRedisClient;
private final AccountLockManager accountLockManager;
private final KeysManager keysManager;
private final MessagesManager messagesManager;
@@ -137,11 +139,19 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
private final Map<String, CompletableFuture<Optional<DeviceInfo>>> waitForDeviceFuturesByTokenIdentifier =
new ConcurrentHashMap<>();
private final Map<TimestampedDeviceIdentifier, CompletableFuture<Optional<RemoteAttachment>>> waitForTransferArchiveFuturesByDeviceIdentifier =
new ConcurrentHashMap<>();
private static final int SHA256_HASH_LENGTH = getSha256MessageDigest().getDigestLength();
private static final Duration RECENTLY_ADDED_DEVICE_TTL = Duration.ofHours(1);
private static final String LINKED_DEVICE_PREFIX = "linked_device::";
private static final String LINKED_DEVICE_KEYSPACE_PATTERN = "__keyspace@0__:" + LINKED_DEVICE_PREFIX + "*";
private static final Duration RECENTLY_ADDED_TRANSFER_ARCHIVE_TTL = Duration.ofHours(1);
private static final String TRANSFER_ARCHIVE_PREFIX = "transfer_archive::";
private static final String TRANSFER_ARCHIVE_KEYSPACE_PATTERN = "__keyspace@0__:" + TRANSFER_ARCHIVE_PREFIX + "*";
private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper()
.writer(SystemMapper.excludingField(Account.class, List.of("uuid")));
@@ -173,10 +183,13 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
}
}
private record TimestampedDeviceIdentifier(UUID accountIdentifier, byte deviceId, Instant deviceCreationTimestamp) {
}
public AccountsManager(final Accounts accounts,
final PhoneNumberIdentifiers phoneNumberIdentifiers,
final FaultTolerantRedisClusterClient cacheCluster,
final FaultTolerantRedisClient pubSubRedisSingleton,
final FaultTolerantRedisClient pubSubRedisClient,
final AccountLockManager accountLockManager,
final KeysManager keysManager,
final MessagesManager messagesManager,
@@ -194,7 +207,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
this.accounts = accounts;
this.phoneNumberIdentifiers = phoneNumberIdentifiers;
this.cacheCluster = cacheCluster;
this.pubSubRedisSingleton = pubSubRedisSingleton;
this.pubSubRedisClient = pubSubRedisClient;
this.accountLockManager = accountLockManager;
this.keysManager = keysManager;
this.messagesManager = messagesManager;
@@ -218,19 +231,23 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
throw new IllegalArgumentException(e);
}
this.pubSubConnection = pubSubRedisSingleton.createPubSubConnection();
this.pubSubConnection = pubSubRedisClient.createPubSubConnection();
}
@Override
public void start() {
pubSubConnection.usePubSubConnection(connection -> connection.addListener(this));
pubSubConnection.usePubSubConnection(connection -> connection.sync().psubscribe(LINKED_DEVICE_KEYSPACE_PATTERN));
pubSubConnection.usePubSubConnection(connection -> {
connection.addListener(this);
connection.sync().psubscribe(LINKED_DEVICE_KEYSPACE_PATTERN, TRANSFER_ARCHIVE_KEYSPACE_PATTERN);
});
}
@Override
public void stop() {
pubSubConnection.usePubSubConnection(connection -> connection.sync().punsubscribe());
pubSubConnection.usePubSubConnection(connection -> connection.removeListener(this));
pubSubConnection.usePubSubConnection(connection -> {
connection.sync().punsubscribe();
connection.removeListener(this);
});
}
public Account create(final String number,
@@ -409,7 +426,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
throw new UncheckedIOException(e);
}
pubSubRedisSingleton.withConnection(connection ->
pubSubRedisClient.withConnection(connection ->
connection.async().set(key, deviceInfoJson, SetArgs.Builder.ex(RECENTLY_ADDED_DEVICE_TTL)))
.whenComplete((ignored, pubSubThrowable) -> {
if (pubSubThrowable != null) {
@@ -1406,51 +1423,11 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
return CompletableFuture.failedFuture(new IllegalArgumentException("Invalid token identifier"));
}
final CompletableFuture<Optional<DeviceInfo>> waitForDeviceFuture = new CompletableFuture<>();
waitForDeviceFuture
.completeOnTimeout(Optional.empty(), TimeUnit.MILLISECONDS.convert(timeout), TimeUnit.MILLISECONDS)
.whenComplete((maybeDevice, throwable) -> waitForDeviceFuturesByTokenIdentifier.compute(linkDeviceTokenIdentifier,
(ignored, existingFuture) -> {
// Only remove the future from the map if it's THIS future, and not one that later displaced this one
return existingFuture == waitForDeviceFuture ? null : existingFuture;
}));
{
final CompletableFuture<Optional<DeviceInfo>> displacedFuture =
waitForDeviceFuturesByTokenIdentifier.put(linkDeviceTokenIdentifier, waitForDeviceFuture);
if (displacedFuture != null) {
displacedFuture.complete(Optional.empty());
}
}
// The device may already have been linked by the time the caller started watching for it; perform an immediate
// check to see if the device is already there.
pubSubRedisSingleton.withConnection(connection -> connection.async().get(getLinkedDeviceKey(linkDeviceTokenIdentifier)))
.thenAccept(response -> {
if (StringUtils.isNotBlank(response)) {
handleDeviceAdded(waitForDeviceFuture, response);
}
});
return waitForDeviceFuture;
}
private static String getLinkedDeviceKey(final String linkDeviceTokenIdentifier) {
return LINKED_DEVICE_PREFIX + linkDeviceTokenIdentifier;
}
@Override
public void message(final String pattern, final String channel, final String message) {
if (LINKED_DEVICE_KEYSPACE_PATTERN.equals(pattern) && "set".equalsIgnoreCase(message)) {
// The `- 1` here compensates for the '*' in the pattern
final String tokenIdentifier = channel.substring(LINKED_DEVICE_KEYSPACE_PATTERN.length() - 1);
Optional.ofNullable(waitForDeviceFuturesByTokenIdentifier.remove(tokenIdentifier))
.ifPresent(future -> pubSubRedisSingleton.withConnection(connection -> connection.async().get(getLinkedDeviceKey(tokenIdentifier)))
.thenAccept(deviceInfoJson -> handleDeviceAdded(future, deviceInfoJson)));
}
return waitForPubSubKey(waitForDeviceFuturesByTokenIdentifier,
linkDeviceTokenIdentifier,
getLinkedDeviceKey(linkDeviceTokenIdentifier),
timeout,
this::handleDeviceAdded);
}
private void handleDeviceAdded(final CompletableFuture<Optional<DeviceInfo>> future, final String deviceInfoJson) {
@@ -1462,6 +1439,134 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
}
}
private static String getLinkedDeviceKey(final String linkDeviceTokenIdentifier) {
return LINKED_DEVICE_PREFIX + linkDeviceTokenIdentifier;
}
public CompletableFuture<Optional<RemoteAttachment>> waitForTransferArchive(final Account account, final Device device, final Duration timeout) {
final TimestampedDeviceIdentifier deviceIdentifier =
new TimestampedDeviceIdentifier(account.getIdentifier(IdentityType.ACI),
device.getId(),
Instant.ofEpochMilli(device.getCreated()));
return waitForPubSubKey(waitForTransferArchiveFuturesByDeviceIdentifier,
deviceIdentifier,
getTransferArchiveKey(account.getIdentifier(IdentityType.ACI), device.getId(), Instant.ofEpochMilli(device.getCreated())),
timeout,
this::handleTransferArchiveAdded);
}
public CompletableFuture<Void> recordTransferArchiveUpload(final Account account,
final byte destinationDeviceId,
final Instant destinationDeviceCreationTimestamp,
final RemoteAttachment transferArchive) {
final String key = getTransferArchiveKey(account.getIdentifier(IdentityType.ACI),
destinationDeviceId,
destinationDeviceCreationTimestamp);
try {
final String transferArchiveJson = SystemMapper.jsonMapper().writeValueAsString(transferArchive);
return pubSubRedisClient.withConnection(connection ->
connection.async().set(key, transferArchiveJson, SetArgs.Builder.ex(RECENTLY_ADDED_TRANSFER_ARCHIVE_TTL)))
.thenRun(Util.NOOP)
.toCompletableFuture();
} catch (final JsonProcessingException e) {
// This should never happen for well-defined objects we control
throw new UncheckedIOException(e);
}
}
private void handleTransferArchiveAdded(final CompletableFuture<Optional<RemoteAttachment>> future, final String transferArchiveJson) {
try {
future.complete(Optional.of(SystemMapper.jsonMapper().readValue(transferArchiveJson, RemoteAttachment.class)));
} catch (final JsonProcessingException e) {
logger.error("Could not parse transfer archive json", e);
future.completeExceptionally(e);
}
}
private static String getTransferArchiveKey(final UUID accountIdentifier,
final byte destinationDeviceId,
final Instant destinationDeviceCreationTimestamp) {
return TRANSFER_ARCHIVE_PREFIX + accountIdentifier.toString() +
":" + destinationDeviceId +
":" + destinationDeviceCreationTimestamp.toEpochMilli();
}
private <K, T> CompletableFuture<Optional<T>> waitForPubSubKey(final Map<K, CompletableFuture<Optional<T>>> futureMap,
final K mapKey,
final String redisKey,
final Duration timeout,
final BiConsumer<CompletableFuture<Optional<T>>, String> handler) {
final CompletableFuture<Optional<T>> future = new CompletableFuture<>();
future.completeOnTimeout(Optional.empty(), TimeUnit.MILLISECONDS.convert(timeout), TimeUnit.MILLISECONDS)
.whenComplete((maybeBackup, throwable) -> futureMap.remove(mapKey, future));
{
final CompletableFuture<Optional<T>> displacedFuture = futureMap.put(mapKey, future);
if (displacedFuture != null) {
displacedFuture.complete(Optional.empty());
}
}
// The Redis key we're waiting for may have been added before the caller issued a request to watch for it; check to
// see if it's already there
pubSubRedisClient.withConnection(connection -> connection.async().get(redisKey))
.thenAccept(response -> {
if (StringUtils.isNotBlank(response)) {
handler.accept(future, response);
}
});
return future;
}
@Override
public void message(final String pattern, final String channel, final String message) {
if (LINKED_DEVICE_KEYSPACE_PATTERN.equals(pattern) && "set".equalsIgnoreCase(message)) {
// The `- 1` here compensates for the '*' in the pattern
final String tokenIdentifier = channel.substring(LINKED_DEVICE_KEYSPACE_PATTERN.length() - 1);
Optional.ofNullable(waitForDeviceFuturesByTokenIdentifier.remove(tokenIdentifier))
.ifPresent(future -> pubSubRedisClient.withConnection(connection -> connection.async().get(getLinkedDeviceKey(tokenIdentifier)))
.thenAccept(deviceInfoJson -> handleDeviceAdded(future, deviceInfoJson)));
} else if (TRANSFER_ARCHIVE_KEYSPACE_PATTERN.equals(pattern) && "set".equalsIgnoreCase(message)) {
// The `- 1` here compensates for the '*' in the pattern
final String[] deviceIdentifierComponents =
channel.substring(TRANSFER_ARCHIVE_KEYSPACE_PATTERN.length() - 1).split(":", 3);
if (deviceIdentifierComponents.length != 3) {
logger.error("Could not parse timestamped device identifier; unexpected component count");
return;
}
try {
final TimestampedDeviceIdentifier deviceIdentifier;
final String transferArchiveKey;
{
final UUID accountIdentifier = UUID.fromString(deviceIdentifierComponents[0]);
final byte deviceId = Byte.parseByte(deviceIdentifierComponents[1]);
final Instant deviceCreationTimestamp = Instant.ofEpochMilli(Long.parseLong(deviceIdentifierComponents[2]));
deviceIdentifier = new TimestampedDeviceIdentifier(accountIdentifier, deviceId, deviceCreationTimestamp);
transferArchiveKey = getTransferArchiveKey(accountIdentifier, deviceId, deviceCreationTimestamp);
}
Optional.ofNullable(waitForTransferArchiveFuturesByDeviceIdentifier.remove(deviceIdentifier))
.ifPresent(future -> pubSubRedisClient.withConnection(connection -> connection.async().get(transferArchiveKey))
.thenAccept(transferArchiveJson -> handleTransferArchiveAdded(future, transferArchiveJson)));
} catch (final IllegalArgumentException e) {
logger.error("Could not parse timestamped device identifier", e);
}
}
}
private static MessageDigest getSha256MessageDigest() {
try {
return MessageDigest.getInstance("SHA-256");