From 575280da748290799a7193de66299d467331d9c8 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Fri, 6 Mar 2026 12:42:04 -0500 Subject: [PATCH] Don't allow linked devices that are missing capabilities required at registration time --- .../java/org/signal/integration/TestUser.java | 3 +- .../controllers/DeviceController.java | 4 ++ .../controllers/RegistrationController.java | 2 +- .../storage/DeviceCapability.java | 33 +++++------ .../controllers/DeviceControllerTest.java | 59 +++++++++++++++++-- .../RegistrationControllerTest.java | 11 ++-- 6 files changed, 82 insertions(+), 30 deletions(-) diff --git a/integration-tests/src/main/java/org/signal/integration/TestUser.java b/integration-tests/src/main/java/org/signal/integration/TestUser.java index 8cc8f36a5..045b03e35 100644 --- a/integration-tests/src/main/java/org/signal/integration/TestUser.java +++ b/integration-tests/src/main/java/org/signal/integration/TestUser.java @@ -14,7 +14,6 @@ import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import org.signal.libsignal.protocol.IdentityKey; @@ -130,7 +129,7 @@ public class TestUser { public AccountAttributes accountAttributes() { return new AccountAttributes(true, registrationId, pniRegistrationId, "".getBytes(StandardCharsets.UTF_8), "", true, - DeviceCapability.CAPABILITIES_REQUIRED_FOR_REGISTRATION) + DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES) .withUnidentifiedAccessKey(unidentifiedAccessKey) .withRecoveryPassword(registrationPassword); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index c0db6bc0f..b802d9316 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -408,6 +408,10 @@ public class DeviceController { } private static boolean isCapabilityDowngrade(final Account account, final Set capabilities) { + if (!capabilities.containsAll(DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES)) { + return true; + } + final Set requiredCapabilities = Arrays.stream(DeviceCapability.values()) // `ALWAYS_CAPABLE` capabilities are always assumed to be present, so we don't require callers to specify them .filter(capability -> capability.getAccountCapabilityMode() != DeviceCapability.AccountCapabilityMode.ALWAYS_CAPABLE) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java index 510707386..efd8c0a2f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java @@ -118,7 +118,7 @@ public class RegistrationController { if (!(registrationRequest.accountAttributes().getCapabilities() != null ? registrationRequest.accountAttributes().getCapabilities() - : Collections.emptySet()).containsAll(DeviceCapability.CAPABILITIES_REQUIRED_FOR_REGISTRATION)) { + : Collections.emptySet()).containsAll(DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES)) { throw new WebApplicationException("Missing required device capability", 499); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceCapability.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceCapability.java index f9fdf8e3e..df08bdcb9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceCapability.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceCapability.java @@ -36,42 +36,41 @@ public enum DeviceCapability { ALWAYS_CAPABLE, } - public static final Set CAPABILITIES_REQUIRED_FOR_REGISTRATION = + public static final Set CAPABILITIES_REQUIRED_FOR_NEW_DEVICES = Arrays.stream(DeviceCapability.values()) - .filter(DeviceCapability::requireForRegistration) + .filter(DeviceCapability::requireForNewDevices) .collect(Collectors.toSet()); private final String name; private final AccountCapabilityMode accountCapabilityMode; private final boolean preventDowngrade; private final boolean includeInProfile; - private final boolean requireForRegistration; + private final boolean requireForNewDevices; /** * Create a DeviceCapability * - * @param name The name of the device capability that clients will see - * @param accountCapabilityMode How to combine the constituent device's capabilities in the account to an overall - * account capability - * @param preventDowngrade If true, don't let linked devices join that don't have a device capability if the - * overall account has the capability. Most of the time this should only be used in - * conjunction with AccountCapabilityMode.ALL_DEVICES. - * @param includeInProfile Whether to return this capability on the account's profile. If false, the capability - * is only visible to the server. - * @param requireForRegistration If true, prevent account creation if the account's initial device does not have this - * capability + * @param name The name of the device capability that clients will see + * @param accountCapabilityMode How to combine the constituent device's capabilities in the account to an overall + * account capability + * @param preventDowngrade If true, don't let linked devices join that don't have a device capability if the + * overall account has the capability. Most of the time this should only be used in + * conjunction with AccountCapabilityMode.ALL_DEVICES. + * @param includeInProfile Whether to return this capability on the account's profile. If false, the capability + * is only visible to the server. + * @param requireForNewDevices If true, prevent device creation if the new device does not have this capability */ DeviceCapability(final String name, final AccountCapabilityMode accountCapabilityMode, final boolean preventDowngrade, final boolean includeInProfile, - final boolean requireForRegistration) { + final boolean requireForNewDevices) { this.name = name; this.accountCapabilityMode = accountCapabilityMode; this.preventDowngrade = preventDowngrade; this.includeInProfile = includeInProfile; - this.requireForRegistration = requireForRegistration; + this.requireForNewDevices = requireForNewDevices; } public String getName() { @@ -90,8 +89,8 @@ public enum DeviceCapability { return includeInProfile; } - public boolean requireForRegistration() { - return requireForRegistration; + public boolean requireForNewDevices() { + return requireForNewDevices; } public static Optional forName(final String name) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index 8d9314d13..0208c66cf 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -32,6 +32,7 @@ import jakarta.ws.rs.core.Response; import java.nio.charset.StandardCharsets; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.Arrays; import java.util.Base64; import java.util.EnumSet; import java.util.HashMap; @@ -40,6 +41,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; import org.apache.commons.lang3.RandomStringUtils; @@ -255,7 +257,7 @@ class DeviceControllerTest { when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null)); final AccountAttributes accountAttributes = new AccountAttributes(fetchesMessages, 1234, 5678, null, - null, true, Set.of()); + null, true, DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES); final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token", accountAttributes, @@ -296,7 +298,7 @@ class DeviceControllerTest { } @CartesianTest - void deviceDowngrade(@CartesianTest.Enum final DeviceCapability capability, + void deviceDowngrade(@CartesianTest.Enum(mode = CartesianTest.Enum.Mode.EXCLUDE, names = "SPARSE_POST_QUANTUM_RATCHET") final DeviceCapability capability, @CartesianTest.Values(booleans = {true, false}) final boolean accountHasCapability, @CartesianTest.Values(booleans = {true, false}) final boolean requestHasCapability) throws LinkDeviceTokenAlreadyUsedException { @@ -354,6 +356,55 @@ class DeviceControllerTest { } } + @Test + void missingRequiredCapability() throws LinkDeviceTokenAlreadyUsedException { + + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account)); + when(accountsManager.addDevice(any(), any(), any())) + .thenReturn(new Pair<>(mock(Account.class), mock(Device.class))); + + final Device primaryDevice = mock(Device.class); + when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID); + when(account.getDevices()).thenReturn(List.of(primaryDevice)); + + final ECSignedPreKey aciSignedPreKey; + final ECSignedPreKey pniSignedPreKey; + final KEMSignedPreKey aciPqLastResortPreKey; + final KEMSignedPreKey pniPqLastResortPreKey; + + final ECKeyPair aciIdentityKeyPair = ECKeyPair.generate(); + final ECKeyPair pniIdentityKeyPair = ECKeyPair.generate(); + + aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair); + pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair); + aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair); + pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair); + + when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); + when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); + + when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null)); + + when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID)); + + final Set requestCapabilities = Arrays.stream(DeviceCapability.values()) + .filter(capability -> !capability.requireForNewDevices()) + .collect(Collectors.toSet()); + + final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token", + new AccountAttributes(false, 1234, 5678, null, null, true, requestCapabilities), + new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id")))); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/link") + .request() + .header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1")) + .put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) { + + assertEquals(409, response.getStatus()); + } + } + @Test void linkDeviceAtomicBadCredentials() { when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account)); @@ -422,7 +473,7 @@ class DeviceControllerTest { when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null)); final AccountAttributes accountAttributes = new AccountAttributes(true, 1234, 5678, null, - null, true, Set.of()); + null, true, DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES); final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token", accountAttributes, @@ -760,7 +811,7 @@ class DeviceControllerTest { when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null)); final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token", - new AccountAttributes(false, registrationId, pniRegistrationId, null, null, true, Set.of()), + new AccountAttributes(false, registrationId, pniRegistrationId, null, null, true, DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES), new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.of(new ApnRegistrationId("apn")), Optional.empty())); try (final Response response = resources.getJerseyTest() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java index f12e3d26f..4cd9b7135 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java @@ -87,7 +87,6 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; -import org.whispersystems.textsecuregcm.util.TestRandomUtil; @ExtendWith(DropwizardExtensionsSupport.class) class RegistrationControllerTest { @@ -543,11 +542,11 @@ class RegistrationControllerTest { final AccountAttributes fetchesMessagesAccountAttributes = new AccountAttributes(true, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, - DeviceCapability.CAPABILITIES_REQUIRED_FOR_REGISTRATION); + DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES); final AccountAttributes pushAccountAttributes = new AccountAttributes(false, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, - DeviceCapability.CAPABILITIES_REQUIRED_FOR_REGISTRATION); + DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES); return List.of( Arguments.argumentSet("\"Fetches messages\" is true, but an APNs token is provided", @@ -634,7 +633,7 @@ class RegistrationControllerTest { final AccountAttributes accountAttributes = new AccountAttributes(true, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, - DeviceCapability.CAPABILITIES_REQUIRED_FOR_REGISTRATION); + DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES); return List.of( Arguments.argumentSet("Signed PNI EC pre-key is missing", @@ -859,7 +858,7 @@ class RegistrationControllerTest { final int registrationId = 1; final int pniRegistrationId = 2; - final Set deviceCapabilities = DeviceCapability.CAPABILITIES_REQUIRED_FOR_REGISTRATION; + final Set deviceCapabilities = DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES; final AccountAttributes fetchesMessagesAccountAttributes = new AccountAttributes(true, registrationId, pniRegistrationId, "test".getBytes(StandardCharsets.UTF_8), null, true, deviceCapabilities); @@ -1012,7 +1011,7 @@ class RegistrationControllerTest { final boolean skipDeviceTransfer, final int registrationId, final int pniRegistrationId) { - return requestToJson(request(sessionId, recoveryPassword, skipDeviceTransfer, registrationId, pniRegistrationId, DeviceCapability.CAPABILITIES_REQUIRED_FOR_REGISTRATION)); + return requestToJson(request(sessionId, recoveryPassword, skipDeviceTransfer, registrationId, pniRegistrationId, DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES)); } /**