Add public methods for fetching accounts asynchronously

This commit is contained in:
Jon Chambers
2023-07-11 23:10:44 -04:00
committed by Jon Chambers
parent 1b7a20619e
commit 41f61c66a3
2 changed files with 366 additions and 9 deletions

View File

@@ -20,6 +20,7 @@ import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.time.Clock;
import java.time.Duration;
import java.util.Arrays;
@@ -39,6 +40,7 @@ import java.util.function.Supplier;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.IdentityKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -58,7 +60,6 @@ import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux;
import reactor.core.publisher.ParallelFlux;
import reactor.core.scheduler.Scheduler;
@@ -89,7 +90,7 @@ public class AccountsManager {
@VisibleForTesting
public static final String USERNAME_EXPERIMENT_NAME = "usernames";
private final Logger logger = LoggerFactory.getLogger(AccountsManager.class);
private static final Logger logger = LoggerFactory.getLogger(AccountsManager.class);
private final Accounts accounts;
private final PhoneNumberIdentifiers phoneNumberIdentifiers;
@@ -686,6 +687,14 @@ public class AccountsManager {
);
}
public CompletableFuture<Optional<Account>> getByE164Async(final String number) {
return checkRedisThenAccountsAsync(
getByNumberTimer,
() -> redisGetBySecondaryKeyAsync(getAccountMapKey(number), redisNumberGetTimer),
() -> accounts.getByE164Async(number)
);
}
public Optional<Account> getByPhoneNumberIdentifier(final UUID pni) {
return checkRedisThenAccounts(
getByNumberTimer,
@@ -694,6 +703,14 @@ public class AccountsManager {
);
}
public CompletableFuture<Optional<Account>> getByPhoneNumberIdentifierAsync(final UUID pni) {
return checkRedisThenAccountsAsync(
getByNumberTimer,
() -> redisGetBySecondaryKeyAsync(getAccountMapKey(pni.toString()), redisPniGetTimer),
() -> accounts.getByPhoneNumberIdentifierAsync(pni)
);
}
public Optional<Account> getByUsernameLinkHandle(final UUID usernameLinkHandle) {
return checkRedisThenAccounts(
getByUsernameLinkHandleTimer,
@@ -718,6 +735,14 @@ public class AccountsManager {
);
}
public CompletableFuture<Optional<Account>> getByAccountIdentifierAsync(final UUID uuid) {
return checkRedisThenAccountsAsync(
getByUuidTimer,
() -> redisGetByAccountIdentifierAsync(uuid),
() -> accounts.getByAccountIdentifierAsync(uuid)
);
}
public UUID getPhoneNumberIdentifier(String e164) {
return phoneNumberIdentifiers.getPhoneNumberIdentifier(e164);
}
@@ -815,6 +840,36 @@ public class AccountsManager {
}
}
private CompletableFuture<Void> redisSetAsync(final Account account) {
final String accountJson;
try {
accountJson = mapper.writeValueAsString(account);
} catch (final JsonProcessingException e) {
throw new UncheckedIOException(e);
}
return cacheCluster.withCluster(connection -> CompletableFuture.allOf(
connection.async().setex(
getAccountMapKey(account.getPhoneNumberIdentifier().toString()), CACHE_TTL_SECONDS,
account.getUuid().toString())
.toCompletableFuture(),
connection.async()
.setex(getAccountMapKey(account.getNumber()), CACHE_TTL_SECONDS, account.getUuid().toString())
.toCompletableFuture(),
connection.async().setex(getAccountEntityKey(account.getUuid()), CACHE_TTL_SECONDS, accountJson)
.toCompletableFuture(),
account.getUsernameHash()
.map(usernameHash -> connection.async()
.setex(getUsernameHashAccountMapKey(usernameHash), CACHE_TTL_SECONDS, account.getUuid().toString())
.toCompletableFuture())
.orElseGet(() -> CompletableFuture.completedFuture(null))
));
}
private Optional<Account> checkRedisThenAccounts(
final Timer overallTimer,
final Supplier<Optional<Account>> resolveFromRedis,
@@ -829,6 +884,23 @@ public class AccountsManager {
}
}
private CompletableFuture<Optional<Account>> checkRedisThenAccountsAsync(
final Timer overallTimer,
final Supplier<CompletableFuture<Optional<Account>>> resolveFromRedis,
final Supplier<CompletableFuture<Optional<Account>>> resolveFromAccounts) {
@SuppressWarnings("resource") final Timer.Context timerContext = overallTimer.time();
return resolveFromRedis.get()
.thenCompose(maybeAccountFromRedis -> maybeAccountFromRedis
.map(accountFromRedis -> CompletableFuture.completedFuture(maybeAccountFromRedis))
.orElseGet(() -> resolveFromAccounts.get()
.thenCompose(maybeAccountFromAccounts -> maybeAccountFromAccounts
.map(account -> redisSetAsync(account).thenApply(ignored -> maybeAccountFromAccounts))
.orElseGet(() -> CompletableFuture.completedFuture(maybeAccountFromAccounts)))))
.whenComplete((ignored, throwable) -> timerContext.close());
}
private Optional<Account> redisGetBySecondaryKey(final String secondaryKey, final Timer timer) {
try (final Timer.Context ignored = timer.time()) {
return Optional.ofNullable(cacheCluster.withCluster(connection -> connection.sync().get(secondaryKey)))
@@ -843,12 +915,50 @@ public class AccountsManager {
}
}
private CompletableFuture<Optional<Account>> redisGetBySecondaryKeyAsync(final String secondaryKey, final Timer timer) {
@SuppressWarnings("resource") final Timer.Context timerContext = timer.time();
return cacheCluster.withCluster(connection -> connection.async().get(secondaryKey))
.thenCompose(nullableUuid -> {
if (nullableUuid != null) {
return getByAccountIdentifierAsync(UUID.fromString(nullableUuid));
} else {
return CompletableFuture.completedFuture(Optional.empty());
}
})
.exceptionally(throwable -> {
logger.warn("Failed to retrieve account from Redis", throwable);
return Optional.empty();
})
.whenComplete((ignored, throwable) -> timerContext.close())
.toCompletableFuture();
}
private Optional<Account> redisGetByAccountIdentifier(UUID uuid) {
try (Timer.Context ignored = redisUuidGetTimer.time()) {
final String json = cacheCluster.withCluster(connection -> connection.sync().get(getAccountEntityKey(uuid)));
if (json != null) {
Account account = mapper.readValue(json, Account.class);
return parseAccountJson(json, uuid);
} catch (final RedisException e) {
logger.warn("Redis failure", e);
return Optional.empty();
}
}
private CompletableFuture<Optional<Account>> redisGetByAccountIdentifierAsync(final UUID uuid) {
return cacheCluster.withCluster(connection -> connection.async().get(getAccountEntityKey(uuid)))
.thenApply(accountJson -> parseAccountJson(accountJson, uuid))
.exceptionally(throwable -> {
logger.warn("Failed to retrieve account from Redis", throwable);
return Optional.empty();
})
.toCompletableFuture();
}
private static Optional<Account> parseAccountJson(@Nullable final String accountJson, final UUID uuid) {
try {
if (StringUtils.isNotBlank(accountJson)) {
Account account = mapper.readValue(accountJson, Account.class);
account.setUuid(uuid);
if (account.getPhoneNumberIdentifier() == null) {
@@ -859,12 +969,9 @@ public class AccountsManager {
}
return Optional.empty();
} catch (IOException e) {
} catch (final IOException e) {
logger.warn("Deserialization error", e);
return Optional.empty();
} catch (RedisException e) {
logger.warn("Redis failure", e);
return Optional.empty();
}
}