PQXDH endpoints for chat server

This commit is contained in:
Jonathan Klabunde Tomer
2023-05-16 17:34:33 -04:00
committed by GitHub
parent 34d77e73ff
commit caae27c44c
30 changed files with 1378 additions and 380 deletions

View File

@@ -341,7 +341,10 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getProfiles().getTableName());
Keys keys = new Keys(dynamoDbClient, config.getDynamoDbTables().getKeys().getTableName());
Keys keys = new Keys(dynamoDbClient,
config.getDynamoDbTables().getEcKeys().getTableName(),
config.getDynamoDbTables().getPqKeys().getTableName(),
config.getDynamoDbTables().getPqLastResortKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getMessages().getTableName(),
config.getDynamoDbTables().getMessages().getExpiration(),

View File

@@ -50,7 +50,9 @@ public class DynamoDbTables {
private final Table deletedAccounts;
private final Table deletedAccountsLock;
private final IssuedReceiptsTableConfiguration issuedReceipts;
private final Table keys;
private final Table ecKeys;
private final Table pqKeys;
private final Table pqLastResortKeys;
private final TableWithExpiration messages;
private final Table pendingAccounts;
private final Table pendingDevices;
@@ -69,7 +71,9 @@ public class DynamoDbTables {
@JsonProperty("deletedAccounts") final Table deletedAccounts,
@JsonProperty("deletedAccountsLock") final Table deletedAccountsLock,
@JsonProperty("issuedReceipts") final IssuedReceiptsTableConfiguration issuedReceipts,
@JsonProperty("keys") final Table keys,
@JsonProperty("ecKeys") final Table ecKeys,
@JsonProperty("pqKeys") final Table pqKeys,
@JsonProperty("pqLastResortKeys") final Table pqLastResortKeys,
@JsonProperty("messages") final TableWithExpiration messages,
@JsonProperty("pendingAccounts") final Table pendingAccounts,
@JsonProperty("pendingDevices") final Table pendingDevices,
@@ -87,7 +91,9 @@ public class DynamoDbTables {
this.deletedAccounts = deletedAccounts;
this.deletedAccountsLock = deletedAccountsLock;
this.issuedReceipts = issuedReceipts;
this.keys = keys;
this.ecKeys = ecKeys;
this.pqKeys = pqKeys;
this.pqLastResortKeys = pqLastResortKeys;
this.messages = messages;
this.pendingAccounts = pendingAccounts;
this.pendingDevices = pendingDevices;
@@ -128,8 +134,20 @@ public class DynamoDbTables {
@NotNull
@Valid
public Table getKeys() {
return keys;
public Table getEcKeys() {
return ecKeys;
}
@NotNull
@Valid
public Table getPqKeys() {
return pqKeys;
}
@NotNull
@Valid
public Table getPqLastResortKeys() {
return pqLastResortKeys;
}
@NotNull

View File

@@ -493,6 +493,7 @@ public class AccountController {
request.number(),
request.pniIdentityKey(),
request.devicePniSignedPrekeys(),
request.devicePniPqLastResortPrekeys(),
request.deviceMessages(),
request.pniRegistrationIds());

View File

@@ -128,6 +128,7 @@ public class AccountControllerV2 {
request.number(),
request.pniIdentityKey(),
request.devicePniSignedPrekeys(),
request.devicePniPqLastResortPrekeys(),
request.deviceMessages(),
request.pniRegistrationIds());
@@ -172,10 +173,11 @@ public class AccountControllerV2 {
}
try {
final Account updatedAccount = changeNumberManager.updatePNIKeys(
final Account updatedAccount = changeNumberManager.updatePniKeys(
authenticatedAccount.getAccount(),
request.pniIdentityKey(),
request.devicePniSignedPrekeys(),
request.devicePniPqLastResortPrekeys(),
request.deviceMessages(),
request.pniRegistrationIds());

View File

@@ -11,14 +11,21 @@ import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.enums.ParameterIn;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes;
@@ -75,12 +82,14 @@ public class KeysController {
@GET
@Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Returns the number of available one-time prekeys for this device")
public PreKeyCount getStatus(@Auth final AuthenticatedAccount auth,
@QueryParam("identity") final Optional<String> identityType) {
int count = keys.getCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
int ecCount = keys.getEcCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
int pqCount = keys.getPqCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
return new PreKeyCount(count);
return new PreKeyCount(ecCount, pqCount);
}
@Timed
@@ -88,9 +97,17 @@ public class KeysController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@ChangesDeviceEnabledState
@Operation(summary = "Sets the identity key for the account or phone-number identity and/or prekeys for this device")
public void setKeys(@Auth final DisabledPermittedAuthenticatedAccount disabledPermittedAuth,
@NotNull @Valid final PreKeyState preKeys,
@RequestBody @NotNull @Valid final PreKeyState preKeys,
@Parameter(allowEmptyValue=true)
@Schema(
allowableValues={"aci", "pni"},
defaultValue="aci",
description="whether this operation applies to the account (aci) or phone-number (pni) identity")
@QueryParam("identity") final Optional<String> identityType,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) {
Account account = disabledPermittedAuth.getAccount();
Device device = disabledPermittedAuth.getAuthenticatedDevice();
@@ -98,7 +115,8 @@ public class KeysController {
final boolean usePhoneNumberIdentity = usePhoneNumberIdentity(identityType);
if (!preKeys.getSignedPreKey().equals(usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey())) {
if (preKeys.getSignedPreKey() != null &&
!preKeys.getSignedPreKey().equals(usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey())) {
updateAccount = true;
}
@@ -121,13 +139,15 @@ public class KeysController {
if (updateAccount) {
account = accounts.update(account, a -> {
a.getDevice(device.getId()).ifPresent(d -> {
if (usePhoneNumberIdentity) {
d.setPhoneNumberIdentitySignedPreKey(preKeys.getSignedPreKey());
} else {
d.setSignedPreKey(preKeys.getSignedPreKey());
}
});
if (preKeys.getSignedPreKey() != null) {
a.getDevice(device.getId()).ifPresent(d -> {
if (usePhoneNumberIdentity) {
d.setPhoneNumberIdentitySignedPreKey(preKeys.getSignedPreKey());
} else {
d.setSignedPreKey(preKeys.getSignedPreKey());
}
});
}
if (usePhoneNumberIdentity) {
a.setPhoneNumberIdentityKey(preKeys.getIdentityKey());
@@ -137,17 +157,29 @@ public class KeysController {
});
}
keys.store(getIdentifier(account, identityType), device.getId(), preKeys.getPreKeys());
keys.store(
getIdentifier(account, identityType), device.getId(),
preKeys.getPreKeys(), preKeys.getPqPreKeys(), preKeys.getPqLastResortPreKey());
}
@Timed
@GET
@Path("/{identifier}/{device_id}")
@Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Retrieves the public identity key and available device prekeys for a specified account or phone-number identity")
public Response getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@Parameter(description="the account or phone-number identifier to retrieve keys for")
@PathParam("identifier") UUID targetUuid,
@Parameter(description="the device id of a single device to retrieve prekeys for, or `*` for all enabled devices")
@PathParam("device_id") String deviceId,
@Parameter(allowEmptyValue=true, description="whether to retrieve post-quantum prekeys")
@Schema(defaultValue="false")
@QueryParam("pq") boolean returnPqKey,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
throws RateLimitExceededException {
@@ -175,28 +207,30 @@ public class KeysController {
final boolean usePhoneNumberIdentity = target.getPhoneNumberIdentifier().equals(targetUuid);
Map<Long, PreKey> preKeysByDeviceId = getLocalKeys(target, deviceId, usePhoneNumberIdentity);
List<PreKeyResponseItem> responseItems = new LinkedList<>();
List<Device> devices = parseDeviceId(deviceId, target);
List<PreKeyResponseItem> responseItems = new ArrayList<>(devices.size());
for (Device device : target.getDevices()) {
if (device.isEnabled() && (deviceId.equals("*") || device.getId() == Long.parseLong(deviceId))) {
SignedPreKey signedPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey();
PreKey preKey = preKeysByDeviceId.get(device.getId());
for (Device device : devices) {
UUID identifier = usePhoneNumberIdentity ? target.getPhoneNumberIdentifier() : targetUuid;
SignedPreKey signedECPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey();
PreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).orElse(null);
SignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).orElse(null) : null;
if (signedPreKey != null || preKey != null) {
final int registrationId = usePhoneNumberIdentity ?
device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) :
device.getRegistrationId();
if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) {
final int registrationId = usePhoneNumberIdentity ?
device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) :
device.getRegistrationId();
responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedPreKey, preKey));
}
responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedECPreKey, unsignedECPreKey, pqPreKey));
}
}
final String identityKey = usePhoneNumberIdentity ? target.getPhoneNumberIdentityKey() : target.getIdentityKey();
if (responseItems.isEmpty()) return Response.status(404).build();
else return Response.ok().entity(new PreKeyResponse(identityKey, responseItems)).build();
if (responseItems.isEmpty()) {
return Response.status(404).build();
}
return Response.ok().entity(new PreKeyResponse(identityKey, responseItems)).build();
}
@Timed
@@ -243,31 +277,15 @@ public class KeysController {
account.getUuid();
}
private Map<Long, PreKey> getLocalKeys(Account destination, String deviceIdSelector, final boolean usePhoneNumberIdentity) {
final Map<Long, PreKey> preKeys;
final UUID identifier = usePhoneNumberIdentity ?
destination.getPhoneNumberIdentifier() :
destination.getUuid();
if (deviceIdSelector.equals("*")) {
preKeys = new HashMap<>();
for (final Device device : destination.getDevices()) {
keys.take(identifier, device.getId()).ifPresent(preKey -> preKeys.put(device.getId(), preKey));
}
} else {
try {
long deviceId = Long.parseLong(deviceIdSelector);
preKeys = keys.take(identifier, deviceId)
.map(preKey -> Map.of(deviceId, preKey))
.orElse(Collections.emptyMap());
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}
private List<Device> parseDeviceId(String deviceId, Account account) {
if (deviceId.equals("*")) {
return account.getDevices().stream().filter(Device::isEnabled).toList();
}
try {
long id = Long.parseLong(deviceId);
return account.getDevice(id).filter(Device::isEnabled).map(List::of).orElse(List.of());
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}
return preKeys;
}
}

View File

@@ -7,6 +7,8 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
@@ -16,21 +18,57 @@ import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
public record ChangeNumberRequest(String sessionId,
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword,
@NotBlank String number,
@JsonProperty("reglock") @Nullable String registrationLock,
@NotBlank String pniIdentityKey,
@NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages,
@NotNull @Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniSignedPrekeys,
@NotNull Map<Long, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
public record ChangeNumberRequest(
@Schema(description="""
A session ID from registration service, if using session id to authenticate this request.
Must not be combined with `recoveryPassword`.""")
String sessionId,
@Schema(description="""
The recovery password for the new phone number, if using a recovery password to authenticate this request.
Must not be combined with `sessionId`.""")
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword,
@Schema(description="the new phone number for this account")
@NotBlank String number,
@Schema(description="the registration lock password for the new phone number, if necessary")
@JsonProperty("reglock") @Nullable String registrationLock,
@Schema(description="the new public identity key to use for the phone-number identity associated with the new phone number")
@NotBlank String pniIdentityKey,
@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keys
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")
@NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages,
@Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@NotNull @Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniSignedPrekeys,
@Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniPqLastResortPrekeys,
@Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one")
@NotNull Map<Long, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
@AssertTrue
public boolean isSignatureValidOnEachSignedPreKey() {
if (devicePniSignedPrekeys == null) {
return true;
List<SignedPreKey> spks = new ArrayList<>();
if (devicePniSignedPrekeys != null) {
spks.addAll(devicePniSignedPrekeys.values());
}
return devicePniSignedPrekeys.values().parallelStream()
.allMatch(spk -> PreKeySignatureValidator.validatePreKeySignature(pniIdentityKey, spk));
if (devicePniPqLastResortPrekeys != null) {
spks.addAll(devicePniPqLastResortPrekeys.values());
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks);
}
}

View File

@@ -6,27 +6,61 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import javax.validation.constraints.AssertTrue;
import javax.annotation.Nullable;
import javax.validation.Valid;
import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
public record ChangePhoneNumberRequest(@NotBlank String number,
@NotBlank String code,
@JsonProperty("reglock") @Nullable String registrationLock,
@Nullable String pniIdentityKey,
@Nullable List<IncomingMessage> deviceMessages,
@Nullable Map<Long, SignedPreKey> devicePniSignedPrekeys,
@Nullable Map<Long, Integer> pniRegistrationIds) {
public record ChangePhoneNumberRequest(
@Schema(description="the new phone number for this account")
@NotBlank String number,
@Schema(description="the registration verification code to authenticate this request")
@NotBlank String code,
@Schema(description="the registration lock password for the new phone number, if necessary")
@JsonProperty("reglock") @Nullable String registrationLock,
@Schema(description="the new public identity key to use for the phone-number identity associated with the new phone number")
@Nullable String pniIdentityKey,
@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keys
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")
@Nullable List<IncomingMessage> deviceMessages,
@Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@Nullable Map<Long, SignedPreKey> devicePniSignedPrekeys,
@Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@Nullable @Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniPqLastResortPrekeys,
@Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one")
@Nullable Map<Long, Integer> pniRegistrationIds) {
@AssertTrue
public boolean isSignatureValidOnEachSignedPreKey() {
if (devicePniSignedPrekeys == null) {
return true;
List<SignedPreKey> spks = new ArrayList<>();
if (devicePniSignedPrekeys != null) {
spks.addAll(devicePniSignedPrekeys.values());
}
return devicePniSignedPrekeys.values().parallelStream()
.allMatch(spk -> PreKeySignatureValidator.validatePreKeySignature(pniIdentityKey, spk));
if (devicePniPqLastResortPrekeys != null) {
spks.addAll(devicePniPqLastResortPrekeys.values());
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks);
}
}

View File

@@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
@@ -17,29 +18,45 @@ import javax.validation.constraints.NotNull;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
public record PhoneNumberIdentityKeyDistributionRequest(
@NotBlank
@Schema(description="the new identity key for this account's phone-number identity")
String pniIdentityKey,
@NotBlank
@Schema(description="the new identity key for this account's phone-number identity")
String pniIdentityKey,
@NotNull
@Valid
@Schema(description="A message for each companion device to pass its new private keys")
List<@NotNull @Valid IncomingMessage> deviceMessages,
@NotNull
@Valid
@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keys
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")
List<@NotNull @Valid IncomingMessage> deviceMessages,
@NotNull
@Valid
@Schema(description="The public key of a new signed elliptic-curve prekey pair for each device")
Map<Long, @NotNull @Valid SignedPreKey> devicePniSignedPrekeys,
@NotNull
@Valid
@Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""")
Map<Long, @NotNull @Valid SignedPreKey> devicePniSignedPrekeys,
@NotNull
@Valid
@Schema(description="The new registration ID to use for the phone-number identity of each device")
Map<Long, Integer> pniRegistrationIds) {
@Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniPqLastResortPrekeys,
@NotNull
@Valid
@Schema(description="The new registration ID to use for the phone-number identity of each device")
Map<Long, Integer> pniRegistrationIds) {
@AssertTrue
public boolean isSignatureValidOnEachSignedPreKey() {
return devicePniSignedPrekeys.values().parallelStream()
.allMatch(spk -> PreKeySignatureValidator.validatePreKeySignature(pniIdentityKey, spk));
List<SignedPreKey> spks = new ArrayList<>(devicePniSignedPrekeys.values());
if (devicePniPqLastResortPrekeys != null) {
spks.addAll(devicePniPqLastResortPrekeys.values());
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks);
}
}

View File

@@ -13,17 +13,17 @@ public class PreKey {
@JsonProperty
@NotNull
private long keyId;
private long keyId;
@JsonProperty
@NotEmpty
private String publicKey;
private String publicKey;
public PreKey() {}
public PreKey(long keyId, String publicKey)
{
this.keyId = keyId;
this.keyId = keyId;
this.publicKey = publicKey;
}
@@ -63,5 +63,4 @@ public class PreKey {
return ((int)this.keyId) ^ publicKey.hashCode();
}
}
}

View File

@@ -5,16 +5,22 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
public class PreKeyCount {
@Schema(description="the number of stored unsigned elliptic-curve prekeys for this device")
@JsonProperty
private int count;
public PreKeyCount(int count) {
this.count = count;
@Schema(description="the number of stored one-time post-quantum prekeys for this device")
@JsonProperty
private int pqCount;
public PreKeyCount(int ecCount, int pqCount) {
this.count = ecCount;
this.pqCount = pqCount;
}
public PreKeyCount() {}
@@ -22,4 +28,8 @@ public class PreKeyCount {
public int getCount() {
return count;
}
public int getPqCount() {
return pqCount;
}
}

View File

@@ -7,15 +7,18 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.List;
public class PreKeyResponse {
@JsonProperty
@Schema(description="the public identity key for the requested identity")
private String identityKey;
@JsonProperty
@Schema(description="information about each requested device")
private List<PreKeyResponseItem> devices;
public PreKeyResponse() {}

View File

@@ -6,28 +6,39 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import io.swagger.v3.oas.annotations.media.Schema;
public class PreKeyResponseItem {
@JsonProperty
@Schema(description="the device ID of the device to which this item pertains")
private long deviceId;
@JsonProperty
@Schema(description="the registration ID for the device")
private int registrationId;
@JsonProperty
@Schema(description="the signed elliptic-curve prekey for the device, if one has been set")
private SignedPreKey signedPreKey;
@JsonProperty
@Schema(description="an unsigned elliptic-curve prekey for the device, if any remain")
private PreKey preKey;
@JsonProperty
@Schema(description="a signed post-quantum prekey for the device " +
"(a one-time prekey if any remain, otherwise the last-resort prekey if one has been set)")
private SignedPreKey pqPreKey;
public PreKeyResponseItem() {}
public PreKeyResponseItem(long deviceId, int registrationId, SignedPreKey signedPreKey, PreKey preKey) {
this.deviceId = deviceId;
public PreKeyResponseItem(long deviceId, int registrationId, SignedPreKey signedPreKey, PreKey preKey, SignedPreKey pqPreKey) {
this.deviceId = deviceId;
this.registrationId = registrationId;
this.signedPreKey = signedPreKey;
this.preKey = preKey;
this.signedPreKey = signedPreKey;
this.preKey = preKey;
this.pqPreKey = pqPreKey;
}
@VisibleForTesting
@@ -40,6 +51,11 @@ public class PreKeyResponseItem {
return preKey;
}
@VisibleForTesting
public SignedPreKey getPqPreKey() {
return pqPreKey;
}
@VisibleForTesting
public int getRegistrationId() {
return registrationId;

View File

@@ -5,24 +5,38 @@
package org.whispersystems.textsecuregcm.entities;
import static com.codahale.metrics.MetricRegistry.name;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.util.Base64;
import java.util.Collection;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
public abstract class PreKeySignatureValidator {
public static final boolean validatePreKeySignature(final String identityKeyB64, final SignedPreKey spk) {
public static final Counter INVALID_SIGNATURE_COUNTER =
Metrics.counter(name(PreKeySignatureValidator.class, "invalidPreKeySignature"));
public static final boolean validatePreKeySignatures(final String identityKeyB64, final Collection<SignedPreKey> spks) {
try {
final byte[] identityKeyBytes = Base64.getDecoder().decode(identityKeyB64);
final byte[] prekeyBytes = Base64.getDecoder().decode(spk.getPublicKey());
final byte[] prekeySignatureBytes = Base64.getDecoder().decode(spk.getSignature());
final ECPublicKey identityKey = Curve.decodePoint(identityKeyBytes, 0);
return identityKey.verifySignature(prekeyBytes, prekeySignatureBytes);
final boolean success = spks.stream().allMatch(spk -> {
final byte[] prekeyBytes = Base64.getDecoder().decode(spk.getPublicKey());
final byte[] prekeySignatureBytes = Base64.getDecoder().decode(spk.getSignature());
return identityKey.verifySignature(prekeyBytes, prekeySignatureBytes);
});
if (!success) {
INVALID_SIGNATURE_COUNTER.increment();
}
return success;
} catch (IllegalArgumentException | InvalidKeyException e) {
Metrics.counter(name(PreKeySignatureValidator.class, "invalidPreKeySignature")).increment();
INVALID_SIGNATURE_COUNTER.increment();
return false;
}
}

View File

@@ -6,6 +6,8 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.ArrayList;
import java.util.List;
import javax.validation.Valid;
import javax.validation.constraints.AssertTrue;
@@ -15,26 +17,59 @@ import javax.validation.constraints.NotNull;
public class PreKeyState {
@JsonProperty
@NotNull
@Valid
@Schema(description="A list of unsigned elliptic-curve prekeys to use for this device. " +
"If present and not empty, replaces all stored unsigned EC prekeys for the device; " +
"if absent or empty, any stored unsigned EC prekeys for the device are not deleted.")
private List<PreKey> preKeys;
@JsonProperty
@NotNull
@Valid
@Schema(description="An optional signed elliptic-curve prekey to use for this device. " +
"If present, replaces the stored signed elliptic-curve prekey for the device; " +
"if absent, the stored signed prekey is not deleted. " +
"If present, must have a valid signature from the identity key in this request.")
private SignedPreKey signedPreKey;
@JsonProperty
@Valid
@Schema(description="A list of signed post-quantum one-time prekeys to use for this device. " +
"Each key must have a valid signature from the identity key in this request. " +
"If present and not empty, replaces all stored unsigned PQ prekeys for the device; " +
"if absent or empty, any stored unsigned PQ prekeys for the device are not deleted.")
private List<SignedPreKey> pqPreKeys;
@JsonProperty
@Valid
@Schema(description="An optional signed last-resort post-quantum prekey to use for this device. " +
"If present, replaces the stored signed post-quantum last-resort prekey for the device; " +
"if absent, a stored last-resort prekey will *not* be deleted. " +
"If present, must have a valid signature from the identity key in this request.")
private SignedPreKey pqLastResortPreKey;
@JsonProperty
@NotEmpty
@NotNull
@Schema(description="Required. " +
"The public identity key for this identity (account or phone-number identity). " +
"If this device is not the primary device for the account, " +
"must match the existing stored identity key for this identity.")
private String identityKey;
public PreKeyState() {}
@VisibleForTesting
public PreKeyState(String identityKey, SignedPreKey signedPreKey, List<PreKey> keys) {
this.identityKey = identityKey;
this.signedPreKey = signedPreKey;
this.preKeys = keys;
this(identityKey, signedPreKey, keys, null, null);
}
@VisibleForTesting
public PreKeyState(String identityKey, SignedPreKey signedPreKey, List<PreKey> keys, List<SignedPreKey> pqKeys, SignedPreKey pqLastResortKey) {
this.identityKey = identityKey;
this.signedPreKey = signedPreKey;
this.preKeys = keys;
this.pqPreKeys = pqKeys;
this.pqLastResortPreKey = pqLastResortKey;
}
public List<PreKey> getPreKeys() {
@@ -45,12 +80,30 @@ public class PreKeyState {
return signedPreKey;
}
public List<SignedPreKey> getPqPreKeys() {
return pqPreKeys;
}
public SignedPreKey getPqLastResortPreKey() {
return pqLastResortPreKey;
}
public String getIdentityKey() {
return identityKey;
}
@AssertTrue
public boolean isSignatureValid() {
return PreKeySignatureValidator.validatePreKeySignature(identityKey, signedPreKey);
public boolean isSignatureValidOnEachSignedKey() {
List<SignedPreKey> spks = new ArrayList<>();
if (pqPreKeys != null) {
spks.addAll(pqPreKeys);
}
if (pqLastResortPreKey != null) {
spks.add(pqLastResortPreKey);
}
if (signedPreKey != null) {
spks.add(signedPreKey);
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(identityKey, spks);
}
}

View File

@@ -45,5 +45,4 @@ public class SignedPreKey extends PreKey {
return super.hashCode() ^ signature.hashCode();
}
}
}

View File

@@ -12,6 +12,7 @@ import static io.micrometer.core.instrument.Metrics.timer;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
@@ -53,7 +54,7 @@ public abstract class AbstractDynamoDbStore {
return dynamoDbClient;
}
protected void executeTableWriteItemsUntilComplete(final Map<String, List<WriteRequest>> items) {
protected void executeTableWriteItemsUntilComplete(final Map<String, ? extends Collection<WriteRequest>> items) {
final AtomicReference<BatchWriteItemResponse> outcome = new AtomicReference<>();
writeAndStoreOutcome(items, batchWriteItemsFirstPass, outcome);
int attemptCount = 0;
@@ -80,7 +81,7 @@ public abstract class AbstractDynamoDbStore {
}
private void writeAndStoreOutcome(
final Map<String, List<WriteRequest>> items,
final Map<String, ? extends Collection<WriteRequest>> items,
final Timer timer,
final AtomicReference<BatchWriteItemResponse> outcome) {
timer.record(

View File

@@ -245,6 +245,7 @@ public class AccountsManager {
public Account changeNumber(final Account account, final String number,
@Nullable final String pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException {
final String originalNumber = account.getNumber();
@@ -252,12 +253,12 @@ public class AccountsManager {
if (originalNumber.equals(number)) {
if (pniIdentityKey != null) {
throw new IllegalArgumentException("change number must supply a changed phone number; otherwise use updatePNIKeys");
throw new IllegalArgumentException("change number must supply a changed phone number; otherwise use updatePniKeys");
}
return account;
}
validateDevices(account, pniSignedPreKeys, pniRegistrationIds);
validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds);
final AtomicReference<Account> updatedAccount = new AtomicReference<>();
@@ -281,7 +282,7 @@ public class AccountsManager {
numberChangedAccount = updateWithRetries(
account,
a -> setPNIKeys(account, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds),
a -> { setPniKeys(account, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); return true; },
a -> accounts.changeNumber(a, number, phoneNumberIdentifier),
() -> accounts.getByAccountIdentifier(uuid).orElseThrow(),
AccountChangeValidator.NUMBER_CHANGE_VALIDATOR);
@@ -291,45 +292,74 @@ public class AccountsManager {
keys.delete(phoneNumberIdentifier);
keys.delete(originalPhoneNumberIdentifier);
if (pniPqLastResortPreKeys != null) {
keys.storePqLastResort(
phoneNumberIdentifier,
keys.getPqEnabledDevices(uuid).stream().collect(
Collectors.toMap(
Function.identity(),
pniPqLastResortPreKeys::get)));
}
return displacedUuid;
});
return updatedAccount.get();
}
public Account updatePNIKeys(final Account account,
public Account updatePniKeys(final Account account,
final String pniIdentityKey,
final Map<Long, SignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> pniPqLastResortPreKeys,
final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException {
validateDevices(account, pniSignedPreKeys, pniRegistrationIds);
validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds);
return update(account, a -> { return setPNIKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); });
final UUID pni = account.getPhoneNumberIdentifier();
final Account updatedAccount = update(account, a -> { return setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); });
final List<Long> pqEnabledDeviceIDs = keys.getPqEnabledDevices(pni);
keys.delete(pni);
if (pniPqLastResortPreKeys != null) {
keys.storePqLastResort(pni, pqEnabledDeviceIDs.stream().collect(Collectors.toMap(Function.identity(), pniPqLastResortPreKeys::get)));
}
return updatedAccount;
}
private boolean setPNIKeys(final Account account,
private boolean setPniKeys(final Account account,
@Nullable final String pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) {
if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) {
return true;
return false;
} else if (!ObjectUtils.allNotNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) {
throw new IllegalArgumentException("PNI identity key, signed pre-keys, and registration IDs must be all null or all non-null");
}
pniSignedPreKeys.forEach((deviceId, signedPreKey) ->
account.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentitySignedPreKey(signedPreKey)));
boolean changed = !pniIdentityKey.equals(account.getPhoneNumberIdentityKey());
for (Device device : account.getDevices()) {
if (!device.isEnabled()) {
continue;
}
SignedPreKey signedPreKey = pniSignedPreKeys.get(device.getId());
int registrationId = pniRegistrationIds.get(device.getId());
changed = changed ||
!signedPreKey.equals(device.getPhoneNumberIdentitySignedPreKey()) ||
device.getRegistrationId() != registrationId;
device.setPhoneNumberIdentitySignedPreKey(signedPreKey);
device.setPhoneNumberIdentityRegistrationId(registrationId);
}
pniRegistrationIds.forEach((deviceId, registrationId) ->
account.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentityRegistrationId(registrationId)));
account.setPhoneNumberIdentityKey(pniIdentityKey);
account.setPhoneNumberIdentityKey(pniIdentityKey);
return true;
return changed;
}
private void validateDevices(final Account account,
final Map<Long, SignedPreKey> pniSignedPreKeys,
final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException {
@Nullable final Map<Long, SignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException {
if (pniSignedPreKeys == null && pniRegistrationIds == null) {
return;
} else if (pniSignedPreKeys == null || pniRegistrationIds == null) {
@@ -342,6 +372,12 @@ public class AccountsManager {
pniSignedPreKeys.keySet(),
Collections.emptySet());
// Check that all including master ID are in Pq pre-keys
DestinationDeviceValidator.validateCompleteDeviceList(
account,
pniSignedPreKeys.keySet(),
Collections.emptySet());
// Check that all devices are accounted for in the map of new PNI registration IDs
DestinationDeviceValidator.validateCompleteDeviceList(
account,

View File

@@ -42,6 +42,7 @@ public class ChangeNumberManager {
public Account changeNumber(final Account account, final String number,
@Nullable final String pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> deviceSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> devicePqLastResortPreKeys,
@Nullable final List<IncomingMessage> deviceMessages,
@Nullable final Map<Long, Integer> pniRegistrationIds)
throws InterruptedException, MismatchedDevicesException, StaleDevicesException {
@@ -62,10 +63,14 @@ public class ChangeNumberManager {
// We don't need to actually do a number-change operation in our DB, but we *do* need to accept their new key
// material and distribute the sync messages, to be sure all clients agree with us and each other about what their
// keys are. Pretend this change-number request was actually a PNI key distribution request.
return updatePNIKeys(account, pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds);
if (pniIdentityKey == null) {
return account;
}
return updatePniKeys(account, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, deviceMessages, pniRegistrationIds);
}
final Account updatedAccount = accountsManager.changeNumber(account, number, pniIdentityKey, deviceSignedPreKeys, pniRegistrationIds);
final Account updatedAccount = accountsManager.changeNumber(
account, number, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, pniRegistrationIds);
if (deviceMessages != null) {
sendDeviceMessages(updatedAccount, deviceMessages);
@@ -74,16 +79,18 @@ public class ChangeNumberManager {
return updatedAccount;
}
public Account updatePNIKeys(final Account account,
public Account updatePniKeys(final Account account,
final String pniIdentityKey,
final Map<Long, SignedPreKey> deviceSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> devicePqLastResortPreKeys,
final List<IncomingMessage> deviceMessages,
final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException, StaleDevicesException {
validateDeviceMessages(account, deviceMessages);
// Don't try to be smart about ignoring unnecessary retries. If we make literally no change we will skip the ddb
// write anyway. Linked devices can handle some wasted extra key rotations.
final Account updatedAccount = accountsManager.updatePNIKeys(account, pniIdentityKey, deviceSignedPreKeys, pniRegistrationIds);
final Account updatedAccount = accountsManager.updatePniKeys(
account, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, pniRegistrationIds);
sendDeviceMessages(updatedAccount, deviceMessages);
return updatedAccount;

View File

@@ -6,6 +6,9 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Multimap;
import com.google.common.collect.MultimapBuilder;
import com.google.common.collect.Multimaps;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
@@ -16,7 +19,11 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
@@ -34,11 +41,14 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
public class Keys extends AbstractDynamoDbStore {
private final String tableName;
private final String ecTableName;
private final String pqTableName;
private final String pqLastResortTableName;
static final String KEY_ACCOUNT_UUID = "U";
static final String KEY_DEVICE_ID_KEY_ID = "DK";
static final String KEY_PUBLIC_KEY = "P";
static final String KEY_SIGNATURE = "S";
private static final Timer STORE_KEYS_TIMER = Metrics.timer(name(Keys.class, "storeKeys"));
private static final Timer TAKE_KEY_FOR_DEVICE_TIMER = Metrics.timer(name(Keys.class, "takeKeyForDevice"));
@@ -48,31 +58,114 @@ public class Keys extends AbstractDynamoDbStore {
private static final DistributionSummary CONTESTED_KEY_DISTRIBUTION = Metrics.summary(name(Keys.class, "contestedKeys"));
private static final DistributionSummary KEY_COUNT_DISTRIBUTION = Metrics.summary(name(Keys.class, "keyCount"));
private static final Counter KEYS_EMPTY_TAKE_COUNTER = Metrics.counter(name(Keys.class, "takeKeyEmpty"));
private static final Counter TOO_MANY_LAST_RESORT_KEYS_COUNTER = Metrics.counter(name(Keys.class, "tooManyLastResortKeys"));
public Keys(final DynamoDbClient dynamoDB, final String tableName) {
public Keys(
final DynamoDbClient dynamoDB,
final String ecTableName,
final String pqTableName,
final String pqLastResortTableName) {
super(dynamoDB);
this.tableName = tableName;
this.ecTableName = ecTableName;
this.pqTableName = pqTableName;
this.pqLastResortTableName = pqLastResortTableName;
}
public void store(final UUID identifier, final long deviceId, final List<PreKey> keys) {
STORE_KEYS_TIMER.record(() -> {
delete(identifier, deviceId);
store(identifier, deviceId, keys, null, null);
}
writeInBatches(keys, batch -> {
List<WriteRequest> items = new ArrayList<>();
for (final PreKey preKey : batch) {
items.add(WriteRequest.builder()
.putRequest(PutRequest.builder()
.item(getItemFromPreKey(identifier, deviceId, preKey))
.build())
.build());
}
executeTableWriteItemsUntilComplete(Map.of(tableName, items));
public void store(
final UUID identifier, final long deviceId,
@Nullable final List<PreKey> ecKeys,
@Nullable final List<SignedPreKey> pqKeys,
@Nullable final SignedPreKey pqLastResortKey) {
Multimap<String, PreKey> keys = MultimapBuilder.hashKeys().arrayListValues().build();
List<String> tablesToClear = new ArrayList<>();
if (ecKeys != null && !ecKeys.isEmpty()) {
keys.putAll(ecTableName, ecKeys);
tablesToClear.add(ecTableName);
}
if (pqKeys != null && !pqKeys.isEmpty()) {
keys.putAll(pqTableName, pqKeys);
tablesToClear.add(pqTableName);
}
if (pqLastResortKey != null) {
keys.put(pqLastResortTableName, pqLastResortKey);
tablesToClear.add(pqLastResortTableName);
}
STORE_KEYS_TIMER.record(() -> {
delete(tablesToClear, identifier, deviceId);
writeInBatches(
keys.entries(),
batch -> {
Multimap<String, WriteRequest> writes = batch.stream()
.collect(
Multimaps.toMultimap(
Map.Entry<String, PreKey>::getKey,
entry -> WriteRequest.builder()
.putRequest(PutRequest.builder()
.item(getItemFromPreKey(identifier, deviceId, entry.getValue()))
.build())
.build(),
MultimapBuilder.hashKeys().arrayListValues()::build));
executeTableWriteItemsUntilComplete(writes.asMap());
});
});
}
public Optional<PreKey> take(final UUID identifier, final long deviceId) {
public void storePqLastResort(final UUID identifier, final Map<Long, SignedPreKey> keys) {
final AttributeValue partitionKey = getPartitionKey(identifier);
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(pqLastResortTableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", partitionKey))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build();
final List<WriteRequest> writes = new ArrayList<>(2 * keys.size());
final Map<Long, Map<String, AttributeValue>> newItems = keys.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> getItemFromPreKey(identifier, e.getKey(), e.getValue())));
for (final Map<String, AttributeValue> item : db().query(queryRequest).items()) {
final AttributeValue oldSortKey = item.get(KEY_DEVICE_ID_KEY_ID);
final Long oldDeviceId = oldSortKey.b().asByteBuffer().getLong();
if (newItems.containsKey(oldDeviceId)) {
final Map<String, AttributeValue> replacement = newItems.get(oldDeviceId);
if (!replacement.get(KEY_DEVICE_ID_KEY_ID).equals(oldSortKey)) {
writes.add(WriteRequest.builder()
.deleteRequest(DeleteRequest.builder()
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, oldSortKey))
.build())
.build());
}
}
}
newItems.forEach((unusedKey, item) ->
writes.add(WriteRequest.builder().putRequest(PutRequest.builder().item(item).build()).build()));
executeTableWriteItemsUntilComplete(Map.of(pqLastResortTableName, writes));
}
public Optional<PreKey> takeEC(final UUID identifier, final long deviceId) {
return take(ecTableName, identifier, deviceId);
}
public Optional<SignedPreKey> takePQ(final UUID identifier, final long deviceId) {
return take(pqTableName, identifier, deviceId)
.or(() -> getLastResort(identifier, deviceId))
.map(pk -> (SignedPreKey) pk);
}
private Optional<PreKey> take(final String tableName, final UUID identifier, final long deviceId) {
return TAKE_KEY_FOR_DEVICE_TIMER.record(() -> {
final AttributeValue partitionKey = getPartitionKey(identifier);
QueryRequest queryRequest = QueryRequest.builder()
@@ -114,7 +207,53 @@ public class Keys extends AbstractDynamoDbStore {
});
}
public int getCount(final UUID identifier, final long deviceId) {
@VisibleForTesting
Optional<PreKey> getLastResort(final UUID identifier, final long deviceId) {
final AttributeValue partitionKey = getPartitionKey(identifier);
QueryRequest queryRequest = QueryRequest.builder()
.tableName(pqLastResortTableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", partitionKey,
":sortprefix", getSortKeyPrefix(deviceId)))
.consistentRead(false)
.select(Select.ALL_ATTRIBUTES)
.build();
QueryResponse response = db().query(queryRequest);
if (response.count() > 1) {
TOO_MANY_LAST_RESORT_KEYS_COUNTER.increment();
}
return response.items().stream().findFirst().map(this::getPreKeyFromItem);
}
public List<Long> getPqEnabledDevices(final UUID identifier) {
final AttributeValue partitionKey = getPartitionKey(identifier);
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(pqLastResortTableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", partitionKey))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(false)
.build();
final QueryResponse response = db().query(queryRequest);
return response.items().stream()
.map(item -> item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong())
.toList();
}
public int getEcCount(final UUID identifier, final long deviceId) {
return getCount(ecTableName, identifier, deviceId);
}
public int getPqCount(final UUID identifier, final long deviceId) {
return getCount(pqTableName, identifier, deviceId);
}
private int getCount(final String tableName, final UUID identifier, final long deviceId) {
return GET_KEY_COUNT_TIMER.record(() -> {
QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName)
@@ -144,51 +283,66 @@ public class Keys extends AbstractDynamoDbStore {
public void delete(final UUID accountUuid) {
DELETE_KEYS_FOR_ACCOUNT_TIMER.record(() -> {
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(accountUuid)))
":uuid", getPartitionKey(accountUuid)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build();
deleteItemsForAccountMatchingQuery(accountUuid, queryRequest);
deleteItemsForAccountMatchingQuery(List.of(ecTableName, pqTableName, pqLastResortTableName), accountUuid, queryRequest);
});
}
public void delete(final UUID accountUuid, final long deviceId) {
delete(List.of(ecTableName, pqTableName, pqLastResortTableName), accountUuid, deviceId);
}
private void delete(final List<String> tableNames, final UUID accountUuid, final long deviceId) {
DELETE_KEYS_FOR_DEVICE_TIMER.record(() -> {
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(accountUuid),
":sortprefix", getSortKeyPrefix(deviceId)))
":uuid", getPartitionKey(accountUuid),
":sortprefix", getSortKeyPrefix(deviceId)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build();
deleteItemsForAccountMatchingQuery(accountUuid, queryRequest);
deleteItemsForAccountMatchingQuery(tableNames, accountUuid, queryRequest);
});
}
private void deleteItemsForAccountMatchingQuery(final UUID accountUuid, final QueryRequest querySpec) {
private void deleteItemsForAccountMatchingQuery(final List<String> tableNames, final UUID accountUuid, final QueryRequest querySpec) {
final AttributeValue partitionKey = getPartitionKey(accountUuid);
writeInBatches(db().query(querySpec).items(), batch -> {
List<WriteRequest> deletes = new ArrayList<>();
for (final Map<String, AttributeValue> item : batch) {
deletes.add(WriteRequest.builder()
.deleteRequest(DeleteRequest.builder()
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, item.get(KEY_DEVICE_ID_KEY_ID)))
.build())
.build());
}
executeTableWriteItemsUntilComplete(Map.of(tableName, deletes));
Multimap<String, Map<String, AttributeValue>> itemStream = tableNames.stream()
.collect(
Multimaps.flatteningToMultimap(
Function.identity(),
tableName ->
db().query(querySpec.toBuilder().tableName(tableName).build())
.items()
.stream(),
MultimapBuilder.hashKeys(tableNames.size()).arrayListValues()::build));
writeInBatches(
itemStream.entries(),
batch -> {
Multimap<String, WriteRequest> deletes = batch.stream()
.collect(Multimaps.toMultimap(
Map.Entry<String, Map<String, AttributeValue>>::getKey,
entry -> WriteRequest.builder()
.deleteRequest(DeleteRequest.builder()
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, entry.getValue().get(KEY_DEVICE_ID_KEY_ID)))
.build())
.build(),
MultimapBuilder.hashKeys(tableNames.size()).arrayListValues()::build));
executeTableWriteItemsUntilComplete(deletes.asMap());
});
}
@@ -211,6 +365,13 @@ public class Keys extends AbstractDynamoDbStore {
}
private Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final long deviceId, final PreKey preKey) {
if (preKey instanceof final SignedPreKey spk) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(accountUuid),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, spk.getKeyId()),
KEY_PUBLIC_KEY, AttributeValues.fromString(spk.getPublicKey()),
KEY_SIGNATURE, AttributeValues.fromString(spk.getSignature()));
}
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(accountUuid),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.getKeyId()),
@@ -219,6 +380,11 @@ public class Keys extends AbstractDynamoDbStore {
private PreKey getPreKeyFromItem(Map<String, AttributeValue> item) {
final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
if (item.containsKey(KEY_SIGNATURE)) {
// All PQ prekeys are signed, and therefore have this attribute. Signed EC prekeys are stored
// in the Accounts table, so EC prekeys retrieved by this class are never SignedPreKeys.
return new SignedPreKey(keyId, item.get(KEY_PUBLIC_KEY).s(), item.get(KEY_SIGNATURE).s());
}
return new PreKey(keyId, item.get(KEY_PUBLIC_KEY).s());
}
}

View File

@@ -174,7 +174,9 @@ public class AssignUsernameCommand extends EnvironmentCommand<WhisperServerConfi
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getProfiles().getTableName());
Keys keys = new Keys(dynamoDbClient,
configuration.getDynamoDbTables().getKeys().getTableName());
configuration.getDynamoDbTables().getEcKeys().getTableName(),
configuration.getDynamoDbTables().getPqKeys().getTableName(),
configuration.getDynamoDbTables().getPqLastResortKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration(),

View File

@@ -154,7 +154,9 @@ record CommandDependencies(
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getProfiles().getTableName());
Keys keys = new Keys(dynamoDbClient,
configuration.getDynamoDbTables().getKeys().getTableName());
configuration.getDynamoDbTables().getEcKeys().getTableName(),
configuration.getDynamoDbTables().getPqKeys().getTableName(),
configuration.getDynamoDbTables().getPqLastResortKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration(),