Make KeysManager storage/retrieval operations asynchronous

This commit is contained in:
Jon Chambers
2023-06-26 11:17:02 -04:00
committed by Jon Chambers
parent 5847300290
commit f709b00be3
10 changed files with 204 additions and 165 deletions

View File

@@ -17,6 +17,7 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes;
@@ -340,12 +341,16 @@ public class DeviceController {
keys.delete(a.getUuid(), device.getId());
keys.delete(a.getPhoneNumberIdentifier(), device.getId());
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
keys.storeEcSignedPreKeys(a.getUuid(), Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get()));
keys.storePqLastResort(a.getUuid(), Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get()));
keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(), Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey().get()));
keys.storePqLastResort(a.getPhoneNumberIdentifier(), Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey().get()));
});
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> CompletableFuture.allOf(
keys.storeEcSignedPreKeys(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get())),
keys.storePqLastResort(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get())),
keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey().get())),
keys.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey().get())))
.join());
a.addDevice(device);
});

View File

@@ -25,6 +25,7 @@ import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes;
@@ -95,10 +96,13 @@ public class KeysController {
public PreKeyCount getStatus(@Auth final AuthenticatedAccount auth,
@QueryParam("identity") final Optional<String> identityType) {
int ecCount = keys.getEcCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
int pqCount = keys.getPqCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
final CompletableFuture<Integer> ecCountFuture =
keys.getEcCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
return new PreKeyCount(ecCount, pqCount);
final CompletableFuture<Integer> pqCountFuture =
keys.getPqCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
return new PreKeyCount(ecCountFuture.join(), pqCountFuture.join());
}
@Timed
@@ -181,8 +185,9 @@ public class KeysController {
}
keys.store(
getIdentifier(account, identityType), device.getId(),
preKeys.getPreKeys(), preKeys.getPqPreKeys(), preKeys.getSignedPreKey(), preKeys.getPqLastResortPreKey());
getIdentifier(account, identityType), device.getId(),
preKeys.getPreKeys(), preKeys.getPqPreKeys(), preKeys.getSignedPreKey(), preKeys.getPqLastResortPreKey())
.join();
}
@Timed
@@ -243,8 +248,8 @@ public class KeysController {
for (Device device : devices) {
UUID identifier = usePhoneNumberIdentity ? target.getPhoneNumberIdentifier() : targetUuid;
ECSignedPreKey signedECPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey();
ECPreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).orElse(null);
KEMSignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).orElse(null) : null;
ECPreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).join().orElse(null);
KEMSignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).join().orElse(null) : null;
compareSignedEcPreKeysExperiment.compareFutureResult(Optional.ofNullable(signedECPreKey),
keys.getEcSignedPreKey(identifier, device.getId()));

View File

@@ -23,6 +23,7 @@ import java.time.Instant;
import java.util.ArrayList;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes;
@@ -176,10 +177,16 @@ public class RegistrationController {
registrationRequest.deviceActivationRequest().gcmToken().ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
keysManager.storeEcSignedPreKeys(a.getUuid(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciSignedPreKey().get()));
keysManager.storePqLastResort(a.getUuid(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey().get()));
keysManager.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniSignedPreKey().get()));
keysManager.storePqLastResort(a.getPhoneNumberIdentifier(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey().get()));
CompletableFuture.allOf(
keysManager.storeEcSignedPreKeys(a.getUuid(),
Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciSignedPreKey().get())),
keysManager.storePqLastResort(a.getUuid(),
Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey().get())),
keysManager.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniSignedPreKey().get())),
keysManager.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey().get())))
.join();
});
}

View File

@@ -332,7 +332,7 @@ public class AccountsManager {
if (pniPqLastResortPreKeys != null) {
keysManager.storePqLastResort(
phoneNumberIdentifier,
keysManager.getPqEnabledDevices(uuid).stream().collect(
keysManager.getPqEnabledDevices(uuid).join().stream().collect(
Collectors.toMap(
Function.identity(),
pniPqLastResortPreKeys::get)));
@@ -367,7 +367,7 @@ public class AccountsManager {
final UUID pni = account.getPhoneNumberIdentifier();
final Account updatedAccount = update(account, a -> { return setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); });
final List<Long> pqEnabledDeviceIDs = keysManager.getPqEnabledDevices(pni);
final List<Long> pqEnabledDeviceIDs = keysManager.getPqEnabledDevices(pni).join();
keysManager.delete(pni);
keysManager.storeEcSignedPreKeys(pni, pniSignedPreKeys);
if (pniPqLastResortPreKeys != null) {

View File

@@ -42,11 +42,11 @@ public class KeysManager {
this.dynamicConfigurationManager = dynamicConfigurationManager;
}
public void store(final UUID identifier, final long deviceId, final List<ECPreKey> keys) {
store(identifier, deviceId, keys, null, null, null);
public CompletableFuture<Void> store(final UUID identifier, final long deviceId, final List<ECPreKey> keys) {
return store(identifier, deviceId, keys, null, null, null);
}
public void store(
public CompletableFuture<Void> store(
final UUID identifier, final long deviceId,
@Nullable final List<ECPreKey> ecKeys,
@Nullable final List<KEMSignedPreKey> pqKeys,
@@ -71,12 +71,14 @@ public class KeysManager {
storeFutures.add(pqLastResortKeys.store(identifier, deviceId, pqLastResortKey));
}
CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0])).join();
return CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0]));
}
public void storeEcSignedPreKeys(final UUID identifier, final Map<Long, ECSignedPreKey> keys) {
public CompletableFuture<Void> storeEcSignedPreKeys(final UUID identifier, final Map<Long, ECSignedPreKey> keys) {
if (dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()) {
ecSignedPreKeys.store(identifier, keys).join();
return ecSignedPreKeys.store(identifier, keys);
} else {
return CompletableFuture.completedFuture(null);
}
}
@@ -84,40 +86,40 @@ public class KeysManager {
return ecSignedPreKeys.storeIfAbsent(identifier, deviceId, signedPreKey);
}
public void storePqLastResort(final UUID identifier, final Map<Long, KEMSignedPreKey> keys) {
pqLastResortKeys.store(identifier, keys).join();
public CompletableFuture<Void> storePqLastResort(final UUID identifier, final Map<Long, KEMSignedPreKey> keys) {
return pqLastResortKeys.store(identifier, keys);
}
public Optional<ECPreKey> takeEC(final UUID identifier, final long deviceId) {
return ecPreKeys.take(identifier, deviceId).join();
public CompletableFuture<Optional<ECPreKey>> takeEC(final UUID identifier, final long deviceId) {
return ecPreKeys.take(identifier, deviceId);
}
public Optional<KEMSignedPreKey> takePQ(final UUID identifier, final long deviceId) {
public CompletableFuture<Optional<KEMSignedPreKey>> takePQ(final UUID identifier, final long deviceId) {
return pqPreKeys.take(identifier, deviceId)
.thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey
.map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey))
.orElseGet(() -> pqLastResortKeys.find(identifier, deviceId))).join();
.orElseGet(() -> pqLastResortKeys.find(identifier, deviceId)));
}
@VisibleForTesting
Optional<KEMSignedPreKey> getLastResort(final UUID identifier, final long deviceId) {
return pqLastResortKeys.find(identifier, deviceId).join();
CompletableFuture<Optional<KEMSignedPreKey>> getLastResort(final UUID identifier, final long deviceId) {
return pqLastResortKeys.find(identifier, deviceId);
}
public CompletableFuture<Optional<ECSignedPreKey>> getEcSignedPreKey(final UUID identifier, final long deviceId) {
return ecSignedPreKeys.find(identifier, deviceId);
}
public List<Long> getPqEnabledDevices(final UUID identifier) {
return pqLastResortKeys.getDeviceIdsWithKeys(identifier).collectList().block();
public CompletableFuture<List<Long>> getPqEnabledDevices(final UUID identifier) {
return pqLastResortKeys.getDeviceIdsWithKeys(identifier).collectList().toFuture();
}
public int getEcCount(final UUID identifier, final long deviceId) {
return ecPreKeys.getCount(identifier, deviceId).join();
public CompletableFuture<Integer> getEcCount(final UUID identifier, final long deviceId) {
return ecPreKeys.getCount(identifier, deviceId);
}
public int getPqCount(final UUID identifier, final long deviceId) {
return pqPreKeys.getCount(identifier, deviceId).join();
public CompletableFuture<Integer> getPqCount(final UUID identifier, final long deviceId) {
return pqPreKeys.getCount(identifier, deviceId);
}
public void delete(final UUID accountUuid) {