Remove signed pre-keys from Device entities

This commit is contained in:
Jon Chambers
2023-12-08 18:43:35 -05:00
committed by Jon Chambers
parent 394f9929ad
commit b048b0bf65
14 changed files with 123 additions and 233 deletions

View File

@@ -15,9 +15,7 @@ import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -41,6 +39,7 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount;
@@ -60,8 +59,6 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v2/keys")
@@ -69,16 +66,16 @@ import reactor.core.publisher.Mono;
public class KeysController {
private final RateLimiters rateLimiters;
private final KeysManager keys;
private final KeysManager keysManager;
private final AccountsManager accounts;
private static final String GET_KEYS_COUNTER_NAME = MetricsUtil.name(KeysController.class, "getKeys");
private static final CompletableFuture<?>[] EMPTY_FUTURE_ARRAY = new CompletableFuture[0];
public KeysController(RateLimiters rateLimiters, KeysManager keys, AccountsManager accounts) {
public KeysController(RateLimiters rateLimiters, KeysManager keysManager, AccountsManager accounts) {
this.rateLimiters = rateLimiters;
this.keys = keys;
this.keysManager = keysManager;
this.accounts = accounts;
}
@@ -92,10 +89,10 @@ public class KeysController {
@QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) {
final CompletableFuture<Integer> ecCountFuture =
keys.getEcCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId());
keysManager.getEcCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId());
final CompletableFuture<Integer> pqCountFuture =
keys.getPqCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId());
keysManager.getPqCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId());
return ecCountFuture.thenCombine(pqCountFuture, PreKeyCount::new);
}
@@ -124,43 +121,25 @@ public class KeysController {
checkSignedPreKeySignatures(setKeysRequest, account.getIdentityKey(identityType));
final CompletableFuture<Account> updateAccountFuture;
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(4);
if (setKeysRequest.signedPreKey() != null &&
!setKeysRequest.signedPreKey().equals(device.getSignedPreKey(identityType))) {
updateAccountFuture = accounts.updateDeviceTransactionallyAsync(account,
device.getId(),
d -> {
switch (identityType) {
case ACI -> d.setSignedPreKey(setKeysRequest.signedPreKey());
case PNI -> d.setPhoneNumberIdentitySignedPreKey(setKeysRequest.signedPreKey());
}
},
d -> List.of(keys.buildWriteItemForEcSignedPreKey(identifier, d.getId(), setKeysRequest.signedPreKey())))
.toCompletableFuture();
} else {
updateAccountFuture = CompletableFuture.completedFuture(account);
if (setKeysRequest.preKeys() != null && !setKeysRequest.preKeys().isEmpty()) {
storeFutures.add(keysManager.storeEcOneTimePreKeys(identifier, device.getId(), setKeysRequest.preKeys()));
}
return updateAccountFuture.thenCompose(updatedAccount -> {
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(3);
if (setKeysRequest.signedPreKey() != null) {
storeFutures.add(keysManager.storeEcSignedPreKeys(identifier, device.getId(), setKeysRequest.signedPreKey()));
}
if (setKeysRequest.preKeys() != null && !setKeysRequest.preKeys().isEmpty()) {
storeFutures.add(keys.storeEcOneTimePreKeys(identifier, device.getId(), setKeysRequest.preKeys()));
}
if (setKeysRequest.pqPreKeys() != null && !setKeysRequest.pqPreKeys().isEmpty()) {
storeFutures.add(keysManager.storeKemOneTimePreKeys(identifier, device.getId(), setKeysRequest.pqPreKeys()));
}
if (setKeysRequest.pqPreKeys() != null && !setKeysRequest.pqPreKeys().isEmpty()) {
storeFutures.add(keys.storeKemOneTimePreKeys(identifier, device.getId(), setKeysRequest.pqPreKeys()));
}
if (setKeysRequest.pqLastResortPreKey() != null) {
storeFutures.add(keysManager.storePqLastResort(identifier, device.getId(), setKeysRequest.pqLastResortPreKey()));
}
if (setKeysRequest.pqLastResortPreKey() != null) {
storeFutures.add(
keys.storePqLastResort(identifier, device.getId(), setKeysRequest.pqLastResortPreKey()));
}
return CompletableFuture.allOf(storeFutures.toArray(EMPTY_FUTURE_ARRAY));
})
return CompletableFuture.allOf(storeFutures.toArray(EMPTY_FUTURE_ARRAY))
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
}
@@ -240,28 +219,41 @@ public class KeysController {
io.micrometer.core.instrument.Tag.of("wildcardDeviceId", String.valueOf("*".equals(deviceId)))))
.increment();
final List<PreKeyResponseItem> responseItems = Flux.fromIterable(parseDeviceId(deviceId, target))
.flatMap(device -> Mono.zip(
Mono.just(device),
Mono.fromFuture(() -> keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId())),
Mono.fromFuture(() -> keys.takeEC(targetIdentifier.uuid(), device.getId())),
Mono.fromFuture(() -> returnPqKey ? keys.takePQ(targetIdentifier.uuid(), device.getId())
: CompletableFuture.<Optional<KEMSignedPreKey>>completedFuture(Optional.empty()))
)).filter(keys -> keys.getT2().isPresent() || keys.getT3().isPresent() || keys.getT4().isPresent())
.map(deviceAndKeys -> {
final Device device = deviceAndKeys.getT1();
final int registrationId = switch (targetIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId());
};
return new PreKeyResponseItem(device.getId(), registrationId,
deviceAndKeys.getT2().orElse(null),
deviceAndKeys.getT3().orElse(null),
deviceAndKeys.getT4().orElse(null));
}).collectList()
.timeout(Duration.ofSeconds(30))
.blockOptional()
.orElse(Collections.emptyList());
final List<Device> devices = parseDeviceId(deviceId, target);
final List<PreKeyResponseItem> responseItems = new ArrayList<>(devices.size());
final List<CompletableFuture<Void>> tasks = devices.stream().map(device -> {
final CompletableFuture<Optional<ECPreKey>> unsignedEcPreKeyFuture =
keysManager.takeEC(targetIdentifier.uuid(), device.getId());
final CompletableFuture<Optional<ECSignedPreKey>> signedEcPreKeyFuture =
keysManager.getEcSignedPreKey(targetIdentifier.uuid(), device.getId());
final CompletableFuture<Optional<KEMSignedPreKey>> pqPreKeyFuture = returnPqKey
? keysManager.takePQ(targetIdentifier.uuid(), device.getId())
: CompletableFuture.completedFuture(Optional.empty());
return CompletableFuture.allOf(unsignedEcPreKeyFuture, signedEcPreKeyFuture, pqPreKeyFuture)
.thenAccept(ignored -> {
final KEMSignedPreKey pqPreKey = pqPreKeyFuture.join().orElse(null);
final ECPreKey unsignedEcPreKey = unsignedEcPreKeyFuture.join().orElse(null);
final ECSignedPreKey signedEcPreKey = signedEcPreKeyFuture.join().orElse(null);
if (signedEcPreKey != null || unsignedEcPreKey != null || pqPreKey != null) {
final int registrationId = switch (targetIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId());
};
responseItems.add(
new PreKeyResponseItem(device.getId(), registrationId, signedEcPreKey, unsignedEcPreKey,
pqPreKey));
}
});
})
.toList();
CompletableFuture.allOf(tasks.toArray(new CompletableFuture[0])).join();
final IdentityKey identityKey = target.getIdentityKey(targetIdentifier.identityType());
@@ -289,16 +281,7 @@ public class KeysController {
final UUID identifier = auth.getAccount().getIdentifier(identityType);
final byte deviceId = auth.getAuthenticatedDevice().getId();
return accounts.updateDeviceTransactionallyAsync(auth.getAccount(),
deviceId,
d -> {
switch (identityType) {
case ACI -> d.setSignedPreKey(signedPreKey);
case PNI -> d.setPhoneNumberIdentitySignedPreKey(signedPreKey);
}
},
d -> List.of(keys.buildWriteItemForEcSignedPreKey(identifier, d.getId(), signedPreKey)))
.toCompletableFuture()
return keysManager.storeEcSignedPreKeys(identifier, deviceId, signedPreKey)
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
}

View File

@@ -12,7 +12,6 @@ import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import org.signal.chat.common.EcPreKey;
import org.signal.chat.common.EcSignedPreKey;
import org.signal.chat.common.KemSignedPreKey;
@@ -40,7 +39,6 @@ import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
@@ -191,18 +189,9 @@ public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase {
KeysGrpcService::checkEcSignedPreKey,
(account, signedPreKey) -> {
final IdentityType identityType = IdentityTypeUtil.fromGrpcIdentityType(request.getIdentityType());
final Consumer<Device> deviceUpdater = switch (identityType) {
case ACI -> device -> device.setSignedPreKey(signedPreKey);
case PNI -> device -> device.setPhoneNumberIdentitySignedPreKey(signedPreKey);
};
final UUID identifier = account.getIdentifier(identityType);
return Flux.merge(
Mono.fromFuture(() -> keysManager.storeEcSignedPreKeys(identifier, authenticatedDevice.deviceId(), signedPreKey)),
Mono.fromFuture(() -> accountsManager.updateDeviceAsync(account, authenticatedDevice.deviceId(), deviceUpdater)))
.then();
return Mono.fromFuture(() -> keysManager.storeEcSignedPreKeys(identifier, authenticatedDevice.deviceId(), signedPreKey));
}));
}

View File

@@ -30,7 +30,6 @@ import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
import java.util.UUID;
@@ -416,7 +415,7 @@ public class AccountsManager {
final Account numberChangedAccount = updateWithRetries(
account,
a -> {
setPniKeys(account, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds);
setPniKeys(account, pniIdentityKey, pniRegistrationIds);
return true;
},
a -> accounts.changeNumber(a, targetNumber, phoneNumberIdentifier, maybeDisplacedUuid, keyWriteItems),
@@ -445,7 +444,7 @@ public class AccountsManager {
return redisDeleteAsync(account)
.thenCompose(ignored -> keysManager.deleteSingleUsePreKeys(pni))
.thenCompose(ignored -> updateTransactionallyWithRetriesAsync(account,
a -> setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds),
a -> setPniKeys(a, pniIdentityKey, pniRegistrationIds),
accounts::updateTransactionallyAsync,
() -> accounts.getByAccountIdentifierAsync(aci).thenApply(Optional::orElseThrow),
a -> keyWriteItems,
@@ -483,28 +482,18 @@ public class AccountsManager {
private void setPniKeys(final Account account,
@Nullable final IdentityKey pniIdentityKey,
@Nullable final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Byte, Integer> pniRegistrationIds) {
if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) {
if (ObjectUtils.allNull(pniIdentityKey, pniRegistrationIds)) {
return;
} else if (!ObjectUtils.allNotNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) {
throw new IllegalArgumentException("PNI identity key, signed pre-keys, and registration IDs must be all null or all non-null");
} else if (!ObjectUtils.allNotNull(pniIdentityKey, pniRegistrationIds)) {
throw new IllegalArgumentException("PNI identity key and registration IDs must be all null or all non-null");
}
boolean changed = !Objects.equals(pniIdentityKey, account.getIdentityKey(IdentityType.PNI));
for (Device device : account.getDevices()) {
if (!device.isEnabled()) {
continue;
}
ECSignedPreKey signedPreKey = pniSignedPreKeys.get(device.getId());
int registrationId = pniRegistrationIds.get(device.getId());
changed = changed ||
!signedPreKey.equals(device.getSignedPreKey(IdentityType.PNI)) ||
device.getRegistrationId() != registrationId;
device.setPhoneNumberIdentitySignedPreKey(signedPreKey);
device.setPhoneNumberIdentityRegistrationId(registrationId);
}
account.getDevices()
.stream()
.filter(Device::isEnabled)
.forEach(device -> device.setPhoneNumberIdentityRegistrationId(pniRegistrationIds.get(device.getId())));
account.setPhoneNumberIdentityKey(pniIdentityKey);
}

View File

@@ -17,8 +17,6 @@ import java.util.stream.IntStream;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.util.DeviceNameByteArrayAdapter;
public class Device {
@@ -72,12 +70,6 @@ public class Device {
@JsonProperty("pniRegistrationId")
private Integer phoneNumberIdentityRegistrationId;
@JsonProperty
private ECSignedPreKey signedPreKey;
@JsonProperty("pniSignedPreKey")
private ECSignedPreKey phoneNumberIdentitySignedPreKey;
@JsonProperty
private long lastSeen;
@@ -247,25 +239,6 @@ public class Device {
this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId;
}
/**
* @deprecated Please retrieve signed pre-keys via {@link KeysManager#getEcSignedPreKey(UUID, byte)} instead
*/
@Deprecated
public ECSignedPreKey getSignedPreKey(final IdentityType identityType) {
return switch (identityType) {
case ACI -> signedPreKey;
case PNI -> phoneNumberIdentitySignedPreKey;
};
}
public void setSignedPreKey(ECSignedPreKey signedPreKey) {
this.signedPreKey = signedPreKey;
}
public void setPhoneNumberIdentitySignedPreKey(final ECSignedPreKey phoneNumberIdentitySignedPreKey) {
this.phoneNumberIdentitySignedPreKey = phoneNumberIdentitySignedPreKey;
}
public long getPushTimestamp() {
return pushTimestamp;
}

View File

@@ -38,8 +38,6 @@ public record DeviceSpec(
device.setCreated(clock.millis());
device.setLastSeen(Util.todayInMillis());
device.setUserAgent(signalAgent());
device.setSignedPreKey(aciSignedPreKey());
device.setPhoneNumberIdentitySignedPreKey(pniSignedPreKey());
apnRegistrationId().ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());