Don't allow linked devices that are missing capabilities required at registration time

This commit is contained in:
Jon Chambers
2026-03-06 12:42:04 -05:00
committed by Jon Chambers
parent 46bfc12869
commit 575280da74
6 changed files with 82 additions and 30 deletions

View File

@@ -14,7 +14,6 @@ import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import org.signal.libsignal.protocol.IdentityKey;
@@ -130,7 +129,7 @@ public class TestUser {
public AccountAttributes accountAttributes() {
return new AccountAttributes(true, registrationId, pniRegistrationId, "".getBytes(StandardCharsets.UTF_8), "", true,
DeviceCapability.CAPABILITIES_REQUIRED_FOR_REGISTRATION)
DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES)
.withUnidentifiedAccessKey(unidentifiedAccessKey)
.withRecoveryPassword(registrationPassword);
}

View File

@@ -408,6 +408,10 @@ public class DeviceController {
}
private static boolean isCapabilityDowngrade(final Account account, final Set<DeviceCapability> capabilities) {
if (!capabilities.containsAll(DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES)) {
return true;
}
final Set<DeviceCapability> requiredCapabilities = Arrays.stream(DeviceCapability.values())
// `ALWAYS_CAPABLE` capabilities are always assumed to be present, so we don't require callers to specify them
.filter(capability -> capability.getAccountCapabilityMode() != DeviceCapability.AccountCapabilityMode.ALWAYS_CAPABLE)

View File

@@ -118,7 +118,7 @@ public class RegistrationController {
if (!(registrationRequest.accountAttributes().getCapabilities() != null
? registrationRequest.accountAttributes().getCapabilities()
: Collections.<DeviceCapability>emptySet()).containsAll(DeviceCapability.CAPABILITIES_REQUIRED_FOR_REGISTRATION)) {
: Collections.<DeviceCapability>emptySet()).containsAll(DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES)) {
throw new WebApplicationException("Missing required device capability", 499);
}

View File

@@ -36,42 +36,41 @@ public enum DeviceCapability {
ALWAYS_CAPABLE,
}
public static final Set<DeviceCapability> CAPABILITIES_REQUIRED_FOR_REGISTRATION =
public static final Set<DeviceCapability> CAPABILITIES_REQUIRED_FOR_NEW_DEVICES =
Arrays.stream(DeviceCapability.values())
.filter(DeviceCapability::requireForRegistration)
.filter(DeviceCapability::requireForNewDevices)
.collect(Collectors.toSet());
private final String name;
private final AccountCapabilityMode accountCapabilityMode;
private final boolean preventDowngrade;
private final boolean includeInProfile;
private final boolean requireForRegistration;
private final boolean requireForNewDevices;
/**
* Create a DeviceCapability
*
* @param name The name of the device capability that clients will see
* @param accountCapabilityMode How to combine the constituent device's capabilities in the account to an overall
* account capability
* @param preventDowngrade If true, don't let linked devices join that don't have a device capability if the
* overall account has the capability. Most of the time this should only be used in
* conjunction with AccountCapabilityMode.ALL_DEVICES.
* @param includeInProfile Whether to return this capability on the account's profile. If false, the capability
* is only visible to the server.
* @param requireForRegistration If true, prevent account creation if the account's initial device does not have this
* capability
* @param name The name of the device capability that clients will see
* @param accountCapabilityMode How to combine the constituent device's capabilities in the account to an overall
* account capability
* @param preventDowngrade If true, don't let linked devices join that don't have a device capability if the
* overall account has the capability. Most of the time this should only be used in
* conjunction with AccountCapabilityMode.ALL_DEVICES.
* @param includeInProfile Whether to return this capability on the account's profile. If false, the capability
* is only visible to the server.
* @param requireForNewDevices If true, prevent device creation if the new device does not have this capability
*/
DeviceCapability(final String name,
final AccountCapabilityMode accountCapabilityMode,
final boolean preventDowngrade,
final boolean includeInProfile,
final boolean requireForRegistration) {
final boolean requireForNewDevices) {
this.name = name;
this.accountCapabilityMode = accountCapabilityMode;
this.preventDowngrade = preventDowngrade;
this.includeInProfile = includeInProfile;
this.requireForRegistration = requireForRegistration;
this.requireForNewDevices = requireForNewDevices;
}
public String getName() {
@@ -90,8 +89,8 @@ public enum DeviceCapability {
return includeInProfile;
}
public boolean requireForRegistration() {
return requireForRegistration;
public boolean requireForNewDevices() {
return requireForNewDevices;
}
public static Optional<DeviceCapability> forName(final String name) {

View File

@@ -32,6 +32,7 @@ import jakarta.ws.rs.core.Response;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Base64;
import java.util.EnumSet;
import java.util.HashMap;
@@ -40,6 +41,7 @@ import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.lang3.RandomStringUtils;
@@ -255,7 +257,7 @@ class DeviceControllerTest {
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final AccountAttributes accountAttributes = new AccountAttributes(fetchesMessages, 1234, 5678, null,
null, true, Set.of());
null, true, DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES);
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
accountAttributes,
@@ -296,7 +298,7 @@ class DeviceControllerTest {
}
@CartesianTest
void deviceDowngrade(@CartesianTest.Enum final DeviceCapability capability,
void deviceDowngrade(@CartesianTest.Enum(mode = CartesianTest.Enum.Mode.EXCLUDE, names = "SPARSE_POST_QUANTUM_RATCHET") final DeviceCapability capability,
@CartesianTest.Values(booleans = {true, false}) final boolean accountHasCapability,
@CartesianTest.Values(booleans = {true, false}) final boolean requestHasCapability)
throws LinkDeviceTokenAlreadyUsedException {
@@ -354,6 +356,55 @@ class DeviceControllerTest {
}
}
@Test
void missingRequiredCapability() throws LinkDeviceTokenAlreadyUsedException {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(accountsManager.addDevice(any(), any(), any()))
.thenReturn(new Pair<>(mock(Account.class), mock(Device.class)));
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(account.getDevices()).thenReturn(List.of(primaryDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
final KEMSignedPreKey aciPqLastResortPreKey;
final KEMSignedPreKey pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = ECKeyPair.generate();
final ECKeyPair pniIdentityKeyPair = ECKeyPair.generate();
aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
final Set<DeviceCapability> requestCapabilities = Arrays.stream(DeviceCapability.values())
.filter(capability -> !capability.requireForNewDevices())
.collect(Collectors.toSet());
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, 1234, 5678, null, null, true, requestCapabilities),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/link")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) {
assertEquals(409, response.getStatus());
}
}
@Test
void linkDeviceAtomicBadCredentials() {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
@@ -422,7 +473,7 @@ class DeviceControllerTest {
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final AccountAttributes accountAttributes = new AccountAttributes(true, 1234, 5678, null,
null, true, Set.of());
null, true, DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES);
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
accountAttributes,
@@ -760,7 +811,7 @@ class DeviceControllerTest {
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, registrationId, pniRegistrationId, null, null, true, Set.of()),
new AccountAttributes(false, registrationId, pniRegistrationId, null, null, true, DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.of(new ApnRegistrationId("apn")), Optional.empty()));
try (final Response response = resources.getJerseyTest()

View File

@@ -87,7 +87,6 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@ExtendWith(DropwizardExtensionsSupport.class)
class RegistrationControllerTest {
@@ -543,11 +542,11 @@ class RegistrationControllerTest {
final AccountAttributes fetchesMessagesAccountAttributes =
new AccountAttributes(true, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true,
DeviceCapability.CAPABILITIES_REQUIRED_FOR_REGISTRATION);
DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES);
final AccountAttributes pushAccountAttributes =
new AccountAttributes(false, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true,
DeviceCapability.CAPABILITIES_REQUIRED_FOR_REGISTRATION);
DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES);
return List.of(
Arguments.argumentSet("\"Fetches messages\" is true, but an APNs token is provided",
@@ -634,7 +633,7 @@ class RegistrationControllerTest {
final AccountAttributes accountAttributes =
new AccountAttributes(true, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true,
DeviceCapability.CAPABILITIES_REQUIRED_FOR_REGISTRATION);
DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES);
return List.of(
Arguments.argumentSet("Signed PNI EC pre-key is missing",
@@ -859,7 +858,7 @@ class RegistrationControllerTest {
final int registrationId = 1;
final int pniRegistrationId = 2;
final Set<DeviceCapability> deviceCapabilities = DeviceCapability.CAPABILITIES_REQUIRED_FOR_REGISTRATION;
final Set<DeviceCapability> deviceCapabilities = DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES;
final AccountAttributes fetchesMessagesAccountAttributes =
new AccountAttributes(true, registrationId, pniRegistrationId, "test".getBytes(StandardCharsets.UTF_8), null, true, deviceCapabilities);
@@ -1012,7 +1011,7 @@ class RegistrationControllerTest {
final boolean skipDeviceTransfer,
final int registrationId,
final int pniRegistrationId) {
return requestToJson(request(sessionId, recoveryPassword, skipDeviceTransfer, registrationId, pniRegistrationId, DeviceCapability.CAPABILITIES_REQUIRED_FOR_REGISTRATION));
return requestToJson(request(sessionId, recoveryPassword, skipDeviceTransfer, registrationId, pniRegistrationId, DeviceCapability.CAPABILITIES_REQUIRED_FOR_NEW_DEVICES));
}
/**