diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 899d76506..9febefcb8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -294,6 +294,7 @@ import org.whispersystems.textsecuregcm.workers.RemoveExpiredAccountsCommand; import org.whispersystems.textsecuregcm.workers.RemoveExpiredBackupsCommand; import org.whispersystems.textsecuregcm.workers.RemoveExpiredLinkedDevicesCommand; import org.whispersystems.textsecuregcm.workers.RemoveExpiredUsernameHoldsCommand; +import org.whispersystems.textsecuregcm.workers.RemoveNonSpqrLinkedDevicesCommand; import org.whispersystems.textsecuregcm.workers.RemoveOrphanedPreKeyPagesCommand; import org.whispersystems.textsecuregcm.workers.ScheduledApnPushNotificationSenderServiceCommand; import org.whispersystems.textsecuregcm.workers.ServerVersionCommand; @@ -356,6 +357,7 @@ public class WhisperServerService extends Application accounts) { + final int maxConcurrency = getNamespace().getInt(MAX_CONCURRENCY_ARGUMENT); + final boolean dryRun = getNamespace().getBoolean(DRY_RUN_ARGUMENT); + + final AccountsManager accountsManager = getCommandDependencies().accountsManager(); + + final Counter removeDeviceCounterName = + Metrics.counter(REMOVE_DEVICE_COUNTER_NAME, "dryRun", String.valueOf(dryRun)); + + accounts + .flatMap(account -> Flux.fromIterable(account.getDevices()) + .filter(device -> !device.isPrimary()) + .filter(device -> !device.hasCapability(DeviceCapability.SPARSE_POST_QUANTUM_RATCHET)) + .map(device -> Tuples.of(account, device.getId()))) + .flatMap(accountAndDeviceId -> { + final Mono removeDeviceMono = dryRun + ? Mono.empty() + : Mono.fromRunnable(() -> accountsManager.removeDevice(accountAndDeviceId.getT1(), accountAndDeviceId.getT2())) + .retryWhen(Retry.backoff(3, Duration.ofSeconds(1))) + .onErrorResume(throwable -> { + logger.warn("Failed to remove device: {}:{}", + accountAndDeviceId.getT1().getIdentifier(IdentityType.ACI), + accountAndDeviceId.getT2(), + throwable); + + return Mono.empty(); + }) + .then(); + + return removeDeviceMono + .doOnSuccess(_ -> removeDeviceCounterName.increment()); + }, maxConcurrency) + .then() + .block(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveNonSpqrLinkedDevicesCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveNonSpqrLinkedDevicesCommandTest.java new file mode 100644 index 000000000..dd8cce385 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveNonSpqrLinkedDevicesCommandTest.java @@ -0,0 +1,98 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.workers; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Map; +import net.sourceforge.argparse4j.inf.Namespace; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.DeviceCapability; +import reactor.core.publisher.Flux; + +class RemoveNonSpqrLinkedDevicesCommandTest { + + private static class TestRemoveNonSpqrLinkedDevicesCommand extends RemoveNonSpqrLinkedDevicesCommand { + + private final CommandDependencies commandDependencies; + private final Namespace namespace; + + public TestRemoveNonSpqrLinkedDevicesCommand(final boolean isDryRun) { + + commandDependencies = mock(CommandDependencies.class); + when(commandDependencies.accountsManager()).thenReturn(mock(AccountsManager.class)); + + namespace = new Namespace(Map.of( + RemoveNonSpqrLinkedDevicesCommand.DRY_RUN_ARGUMENT, isDryRun, + RemoveNonSpqrLinkedDevicesCommand.MAX_CONCURRENCY_ARGUMENT, 16)); + } + + @Override + protected CommandDependencies getCommandDependencies() { + return commandDependencies; + } + + @Override + protected Namespace getNamespace() { + return namespace; + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void crawlAccounts(final boolean dryRun) { + final Device primaryDeviceWithSpqr = buildMockDevice(true, true); + final Device primaryDeviceWithoutSpqr = buildMockDevice(true, false); + final Device linkedDeviceWithSpqr = buildMockDevice(false, true); + final Device linkedDeviceWithoutSpqr = buildMockDevice(false, false); + + final Account accountWithNonSpqrPrimary = mock(Account.class); + when(accountWithNonSpqrPrimary.getDevices()) + .thenReturn(List.of(primaryDeviceWithoutSpqr)); + + final Account accountWithSpqrLinkedDevice = mock(Account.class); + when(accountWithSpqrLinkedDevice.getDevices()) + .thenReturn(List.of(primaryDeviceWithSpqr, linkedDeviceWithSpqr)); + + final Account accountWithNonSpqrLinkedDevice = mock(Account.class); + when(accountWithNonSpqrLinkedDevice.getDevices()) + .thenReturn(List.of(primaryDeviceWithSpqr, linkedDeviceWithoutSpqr)); + + final RemoveNonSpqrLinkedDevicesCommand removeNonSpqrLinkedDevicesCommand = + new TestRemoveNonSpqrLinkedDevicesCommand(dryRun); + + removeNonSpqrLinkedDevicesCommand.crawlAccounts(Flux.just( + accountWithNonSpqrPrimary, accountWithSpqrLinkedDevice, accountWithNonSpqrLinkedDevice)); + + final AccountsManager accountsManager = + removeNonSpqrLinkedDevicesCommand.getCommandDependencies().accountsManager(); + + if (dryRun) { + verifyNoInteractions(accountsManager); + } else { + verify(accountsManager).removeDevice(accountWithNonSpqrLinkedDevice, linkedDeviceWithoutSpqr.getId()); + verifyNoMoreInteractions(accountsManager); + } + } + + private Device buildMockDevice(final boolean isPrimary, final boolean supportsSpqr) { + final Device device = mock(Device.class); + when(device.isPrimary()).thenReturn(isPrimary); + when(device.getId()).thenReturn(isPrimary ? Device.PRIMARY_ID : Device.PRIMARY_ID + 1); + when(device.hasCapability(DeviceCapability.SPARSE_POST_QUANTUM_RATCHET)).thenReturn(supportsSpqr); + + return device; + } +}