diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java index 614ff5336..510707386 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java @@ -32,6 +32,7 @@ import jakarta.ws.rs.core.Response; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; +import java.util.Collections; import java.util.Optional; import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader; import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager; @@ -100,6 +101,7 @@ public class RegistrationController { @ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header( name = "Retry-After", description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed")) + @ApiResponse(responseCode = "499", description = "Client must support post-quantum ratchet") public AccountCreationResponse register( @HeaderParam(HttpHeaders.AUTHORIZATION) @NotNull final BasicAuthorizationHeader authorizationHeader, @HeaderParam(HeaderUtils.X_SIGNAL_AGENT) final String signalAgent, @@ -114,6 +116,13 @@ public class RegistrationController { throw new WebApplicationException("Invalid signature", 422); } + if (!(registrationRequest.accountAttributes().getCapabilities() != null + ? registrationRequest.accountAttributes().getCapabilities() + : Collections.emptySet()).containsAll(DeviceCapability.CAPABILITIES_REQUIRED_FOR_REGISTRATION)) { + + throw new WebApplicationException("Missing required device capability", 499); + } + rateLimiters.getRegistrationLimiter().validate(number); final PhoneVerificationRequest.VerificationType verificationType = phoneVerificationTokenManager.verify( diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceCapability.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceCapability.java index e21e68562..d6de1aa04 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceCapability.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceCapability.java @@ -5,13 +5,15 @@ package org.whispersystems.textsecuregcm.storage; +import java.util.Arrays; +import java.util.Collection; import java.util.Optional; public enum DeviceCapability { - STORAGE("storage", AccountCapabilityMode.ANY_DEVICE, false, false), - TRANSFER("transfer", AccountCapabilityMode.PRIMARY_DEVICE, false, false), - ATTACHMENT_BACKFILL("attachmentBackfill", AccountCapabilityMode.PRIMARY_DEVICE, false, true), - SPARSE_POST_QUANTUM_RATCHET("spqr", AccountCapabilityMode.ALL_DEVICES, false, true); + STORAGE("storage", AccountCapabilityMode.ANY_DEVICE, false, false, false), + TRANSFER("transfer", AccountCapabilityMode.PRIMARY_DEVICE, false, false, false), + ATTACHMENT_BACKFILL("attachmentBackfill", AccountCapabilityMode.PRIMARY_DEVICE, false, true, false), + SPARSE_POST_QUANTUM_RATCHET("spqr", AccountCapabilityMode.ALL_DEVICES, false, true, true); public enum AccountCapabilityMode { /** @@ -33,32 +35,42 @@ public enum DeviceCapability { ALWAYS_CAPABLE, } + public static final Collection CAPABILITIES_REQUIRED_FOR_REGISTRATION = + Arrays.stream(DeviceCapability.values()) + .filter(DeviceCapability::requireForRegistration) + .toList(); + private final String name; private final AccountCapabilityMode accountCapabilityMode; private final boolean preventDowngrade; private final boolean includeInProfile; + private final boolean requireForRegistration; /** * Create a DeviceCapability * - * @param name The name of the device capability that clients will see - * @param accountCapabilityMode How to combine the constituent device's capabilities in the account to an overall - * account capability - * @param preventDowngrade If true, don't let linked devices join that don't have a device capability if the - * overall account has the capability. Most of the time this should only be used in - * conjunction with AccountCapabilityMode.ALL_DEVICES - * @param includeInProfile Whether to return this capability on the account's profile. If false, the capability - * is only visible to the server + * @param name The name of the device capability that clients will see + * @param accountCapabilityMode How to combine the constituent device's capabilities in the account to an overall + * account capability + * @param preventDowngrade If true, don't let linked devices join that don't have a device capability if the + * overall account has the capability. Most of the time this should only be used in + * conjunction with AccountCapabilityMode.ALL_DEVICES. + * @param includeInProfile Whether to return this capability on the account's profile. If false, the capability + * is only visible to the server. + * @param requireForRegistration If true, prevent account creation if the account's initial device does not have this + * capability */ DeviceCapability(final String name, final AccountCapabilityMode accountCapabilityMode, final boolean preventDowngrade, - final boolean includeInProfile) { + final boolean includeInProfile, + final boolean requireForRegistration) { this.name = name; this.accountCapabilityMode = accountCapabilityMode; this.preventDowngrade = preventDowngrade; this.includeInProfile = includeInProfile; + this.requireForRegistration = requireForRegistration; } public String getName() { @@ -77,6 +89,10 @@ public enum DeviceCapability { return includeInProfile; } + public boolean requireForRegistration() { + return requireForRegistration; + } + public static Optional forName(final String name) { for (final DeviceCapability capability : DeviceCapability.values()) { if (capability.getName().equals(name)) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java index d4c57a19f..44a11e188 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java @@ -91,6 +91,8 @@ import org.whispersystems.textsecuregcm.util.SystemMapper; class RegistrationControllerTest { private static final long SESSION_EXPIRATION_SECONDS = Duration.ofMinutes(10).toSeconds(); + private static final Set SPQR_DEVICE_CAPABILITIES = + Set.of(DeviceCapability.SPARSE_POST_QUANTUM_RATCHET); private static final String NUMBER = PhoneNumberUtil.getInstance().format( PhoneNumberUtil.getInstance().getExampleNumber("US"), @@ -538,10 +540,12 @@ class RegistrationControllerTest { } final AccountAttributes fetchesMessagesAccountAttributes = - new AccountAttributes(true, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, Set.of()); + new AccountAttributes(true, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, + SPQR_DEVICE_CAPABILITIES); final AccountAttributes pushAccountAttributes = - new AccountAttributes(false, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, Set.of()); + new AccountAttributes(false, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, + SPQR_DEVICE_CAPABILITIES); return List.of( Arguments.argumentSet("\"Fetches messages\" is true, but an APNs token is provided", @@ -627,7 +631,8 @@ class RegistrationControllerTest { } final AccountAttributes accountAttributes = - new AccountAttributes(true, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, Set.of()); + new AccountAttributes(true, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, + SPQR_DEVICE_CAPABILITIES); return List.of( Arguments.argumentSet("Signed PNI EC pre-key is missing", @@ -792,6 +797,30 @@ class RegistrationControllerTest { } } + @Test + void registrationMissingSpqrCapability() throws Exception { + when(registrationServiceClient.getSession(any(), any())) + .thenReturn( + CompletableFuture.completedFuture( + Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null, + SESSION_EXPIRATION_SECONDS)))); + + final Account account = mock(Account.class); + when(account.getPrimaryDevice()).thenReturn(mock(Device.class)); + + when(accountsManager.create(any(), any(), any(), any(), any(), any(), any())) + .thenReturn(account); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/registration") + .request() + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD)); + final RegistrationRequest requestObj = request("sessionId", new byte[0], false, 1, 2, Collections.emptySet()); + try (final Response response = request.post(Entity.json(requestToJson(requestObj)))) { + assertEquals(499, response.getStatus()); + } + } + private static boolean accountAttributesEqual(final AccountAttributes a, final AccountAttributes b) { return a.getFetchesMessages() == b.getFetchesMessages() && a.getRegistrationId() == b.getRegistrationId() @@ -828,7 +857,7 @@ class RegistrationControllerTest { final int registrationId = 1; final int pniRegistrationId = 2; - final Set deviceCapabilities = Set.of(); + final Set deviceCapabilities = SPQR_DEVICE_CAPABILITIES; final AccountAttributes fetchesMessagesAccountAttributes = new AccountAttributes(true, registrationId, pniRegistrationId, "test".getBytes(StandardCharsets.UTF_8), null, true, deviceCapabilities); @@ -932,25 +961,24 @@ class RegistrationControllerTest { ); } - /** - * Valid request JSON with the give session ID and skipDeviceTransfer - */ - private static String requestJson(final String sessionId, + private static RegistrationRequest request( + final String sessionId, final byte[] recoveryPassword, final boolean skipDeviceTransfer, final int registrationId, - int pniRegistrationId) { - + int pniRegistrationId, + Set deviceCapabilities) { final ECKeyPair aciIdentityKeyPair = ECKeyPair.generate(); final ECKeyPair pniIdentityKeyPair = ECKeyPair.generate(); final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); - final AccountAttributes accountAttributes = new AccountAttributes(true, registrationId, pniRegistrationId, "name".getBytes(StandardCharsets.UTF_8), "reglock", - true, Set.of()); + final AccountAttributes accountAttributes = new AccountAttributes(true, registrationId, pniRegistrationId, + "name".getBytes(StandardCharsets.UTF_8), "reglock", + true, deviceCapabilities); - final RegistrationRequest request = new RegistrationRequest( + return new RegistrationRequest( Base64.getEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8)), recoveryPassword, accountAttributes, @@ -964,6 +992,9 @@ class RegistrationControllerTest { KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair), Optional.empty(), Optional.empty())); + } + + private static String requestToJson(RegistrationRequest request) { try { return SystemMapper.jsonMapper().writerWithDefaultPrettyPrinter().writeValueAsString(request); } catch (final JsonProcessingException e) { @@ -971,6 +1002,17 @@ class RegistrationControllerTest { } } + /** + * Valid request JSON with the give session ID and skipDeviceTransfer + */ + private static String requestJson(final String sessionId, + final byte[] recoveryPassword, + final boolean skipDeviceTransfer, + final int registrationId, + final int pniRegistrationId) { + return requestToJson(request(sessionId, recoveryPassword, skipDeviceTransfer, registrationId, pniRegistrationId, SPQR_DEVICE_CAPABILITIES)); + } + /** * Valid request JSON with the given session ID */