PQXDH endpoints for chat server

This commit is contained in:
Jonathan Klabunde Tomer
2023-05-16 17:34:33 -04:00
committed by GitHub
parent 34d77e73ff
commit caae27c44c
30 changed files with 1378 additions and 380 deletions

View File

@@ -493,6 +493,7 @@ public class AccountController {
request.number(),
request.pniIdentityKey(),
request.devicePniSignedPrekeys(),
request.devicePniPqLastResortPrekeys(),
request.deviceMessages(),
request.pniRegistrationIds());

View File

@@ -128,6 +128,7 @@ public class AccountControllerV2 {
request.number(),
request.pniIdentityKey(),
request.devicePniSignedPrekeys(),
request.devicePniPqLastResortPrekeys(),
request.deviceMessages(),
request.pniRegistrationIds());
@@ -172,10 +173,11 @@ public class AccountControllerV2 {
}
try {
final Account updatedAccount = changeNumberManager.updatePNIKeys(
final Account updatedAccount = changeNumberManager.updatePniKeys(
authenticatedAccount.getAccount(),
request.pniIdentityKey(),
request.devicePniSignedPrekeys(),
request.devicePniPqLastResortPrekeys(),
request.deviceMessages(),
request.pniRegistrationIds());

View File

@@ -11,14 +11,21 @@ import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.enums.ParameterIn;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes;
@@ -75,12 +82,14 @@ public class KeysController {
@GET
@Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Returns the number of available one-time prekeys for this device")
public PreKeyCount getStatus(@Auth final AuthenticatedAccount auth,
@QueryParam("identity") final Optional<String> identityType) {
int count = keys.getCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
int ecCount = keys.getEcCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
int pqCount = keys.getPqCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
return new PreKeyCount(count);
return new PreKeyCount(ecCount, pqCount);
}
@Timed
@@ -88,9 +97,17 @@ public class KeysController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@ChangesDeviceEnabledState
@Operation(summary = "Sets the identity key for the account or phone-number identity and/or prekeys for this device")
public void setKeys(@Auth final DisabledPermittedAuthenticatedAccount disabledPermittedAuth,
@NotNull @Valid final PreKeyState preKeys,
@RequestBody @NotNull @Valid final PreKeyState preKeys,
@Parameter(allowEmptyValue=true)
@Schema(
allowableValues={"aci", "pni"},
defaultValue="aci",
description="whether this operation applies to the account (aci) or phone-number (pni) identity")
@QueryParam("identity") final Optional<String> identityType,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) {
Account account = disabledPermittedAuth.getAccount();
Device device = disabledPermittedAuth.getAuthenticatedDevice();
@@ -98,7 +115,8 @@ public class KeysController {
final boolean usePhoneNumberIdentity = usePhoneNumberIdentity(identityType);
if (!preKeys.getSignedPreKey().equals(usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey())) {
if (preKeys.getSignedPreKey() != null &&
!preKeys.getSignedPreKey().equals(usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey())) {
updateAccount = true;
}
@@ -121,13 +139,15 @@ public class KeysController {
if (updateAccount) {
account = accounts.update(account, a -> {
a.getDevice(device.getId()).ifPresent(d -> {
if (usePhoneNumberIdentity) {
d.setPhoneNumberIdentitySignedPreKey(preKeys.getSignedPreKey());
} else {
d.setSignedPreKey(preKeys.getSignedPreKey());
}
});
if (preKeys.getSignedPreKey() != null) {
a.getDevice(device.getId()).ifPresent(d -> {
if (usePhoneNumberIdentity) {
d.setPhoneNumberIdentitySignedPreKey(preKeys.getSignedPreKey());
} else {
d.setSignedPreKey(preKeys.getSignedPreKey());
}
});
}
if (usePhoneNumberIdentity) {
a.setPhoneNumberIdentityKey(preKeys.getIdentityKey());
@@ -137,17 +157,29 @@ public class KeysController {
});
}
keys.store(getIdentifier(account, identityType), device.getId(), preKeys.getPreKeys());
keys.store(
getIdentifier(account, identityType), device.getId(),
preKeys.getPreKeys(), preKeys.getPqPreKeys(), preKeys.getPqLastResortPreKey());
}
@Timed
@GET
@Path("/{identifier}/{device_id}")
@Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Retrieves the public identity key and available device prekeys for a specified account or phone-number identity")
public Response getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@Parameter(description="the account or phone-number identifier to retrieve keys for")
@PathParam("identifier") UUID targetUuid,
@Parameter(description="the device id of a single device to retrieve prekeys for, or `*` for all enabled devices")
@PathParam("device_id") String deviceId,
@Parameter(allowEmptyValue=true, description="whether to retrieve post-quantum prekeys")
@Schema(defaultValue="false")
@QueryParam("pq") boolean returnPqKey,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
throws RateLimitExceededException {
@@ -175,28 +207,30 @@ public class KeysController {
final boolean usePhoneNumberIdentity = target.getPhoneNumberIdentifier().equals(targetUuid);
Map<Long, PreKey> preKeysByDeviceId = getLocalKeys(target, deviceId, usePhoneNumberIdentity);
List<PreKeyResponseItem> responseItems = new LinkedList<>();
List<Device> devices = parseDeviceId(deviceId, target);
List<PreKeyResponseItem> responseItems = new ArrayList<>(devices.size());
for (Device device : target.getDevices()) {
if (device.isEnabled() && (deviceId.equals("*") || device.getId() == Long.parseLong(deviceId))) {
SignedPreKey signedPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey();
PreKey preKey = preKeysByDeviceId.get(device.getId());
for (Device device : devices) {
UUID identifier = usePhoneNumberIdentity ? target.getPhoneNumberIdentifier() : targetUuid;
SignedPreKey signedECPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey();
PreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).orElse(null);
SignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).orElse(null) : null;
if (signedPreKey != null || preKey != null) {
final int registrationId = usePhoneNumberIdentity ?
device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) :
device.getRegistrationId();
if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) {
final int registrationId = usePhoneNumberIdentity ?
device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) :
device.getRegistrationId();
responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedPreKey, preKey));
}
responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedECPreKey, unsignedECPreKey, pqPreKey));
}
}
final String identityKey = usePhoneNumberIdentity ? target.getPhoneNumberIdentityKey() : target.getIdentityKey();
if (responseItems.isEmpty()) return Response.status(404).build();
else return Response.ok().entity(new PreKeyResponse(identityKey, responseItems)).build();
if (responseItems.isEmpty()) {
return Response.status(404).build();
}
return Response.ok().entity(new PreKeyResponse(identityKey, responseItems)).build();
}
@Timed
@@ -243,31 +277,15 @@ public class KeysController {
account.getUuid();
}
private Map<Long, PreKey> getLocalKeys(Account destination, String deviceIdSelector, final boolean usePhoneNumberIdentity) {
final Map<Long, PreKey> preKeys;
final UUID identifier = usePhoneNumberIdentity ?
destination.getPhoneNumberIdentifier() :
destination.getUuid();
if (deviceIdSelector.equals("*")) {
preKeys = new HashMap<>();
for (final Device device : destination.getDevices()) {
keys.take(identifier, device.getId()).ifPresent(preKey -> preKeys.put(device.getId(), preKey));
}
} else {
try {
long deviceId = Long.parseLong(deviceIdSelector);
preKeys = keys.take(identifier, deviceId)
.map(preKey -> Map.of(deviceId, preKey))
.orElse(Collections.emptyMap());
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}
private List<Device> parseDeviceId(String deviceId, Account account) {
if (deviceId.equals("*")) {
return account.getDevices().stream().filter(Device::isEnabled).toList();
}
try {
long id = Long.parseLong(deviceId);
return account.getDevice(id).filter(Device::isEnabled).map(List::of).orElse(List.of());
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}
return preKeys;
}
}