diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManager.java index e97a92cb2..bb6bac3d6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManager.java @@ -8,22 +8,23 @@ package org.whispersystems.textsecuregcm.auth; import com.google.protobuf.InvalidProtocolBufferException; import io.dropwizard.lifecycle.Managed; import io.lettuce.core.pubsub.RedisPubSubAdapter; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; import java.nio.charset.StandardCharsets; import java.util.Collection; -import java.util.Collections; import java.util.List; import java.util.UUID; import java.util.concurrent.CompletionStage; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.Executor; import javax.annotation.Nullable; -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Metrics; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; +import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.UUIDUtil; @@ -96,12 +97,13 @@ public class DisconnectionRequestManager extends RedisPubSubAdapter requestDisconnection(final UUID accountIdentifier) { - return requestDisconnection(accountIdentifier, Collections.emptyList()); + public CompletionStage requestDisconnection(final Account account) { + return requestDisconnection(account.getIdentifier(IdentityType.ACI), + account.getDevices().stream().map(Device::getId).toList()); } /** @@ -135,8 +137,7 @@ public class DisconnectionRequestManager extends RedisPubSubAdapter 0 - ? disconnectionRequest.getDeviceIdsList().stream() + deviceIds = disconnectionRequest.getDeviceIdsList().stream() .map(deviceIdInt -> { if (deviceIdInt == null || deviceIdInt < Device.PRIMARY_ID || deviceIdInt > Byte.MAX_VALUE) { throw new IllegalArgumentException("Invalid device ID: " + deviceIdInt); @@ -144,8 +145,7 @@ public class DisconnectionRequestManager extends RedisPubSubAdapter implemen keysManager.deleteSingleUsePreKeys(pni), messagesManager.clear(aci), profilesManager.deleteAll(aci, false)) - .thenCompose(ignored -> disconnectionRequestManager.requestDisconnection(aci)) + .thenCompose(ignored -> disconnectionRequestManager.requestDisconnection(e.getExistingAccount())) .thenCompose(ignored -> accounts.reclaimAccount(e.getExistingAccount(), account, additionalWriteItems)) .thenCompose(ignored -> { // We should have cleared all messages before overwriting the old account, but more may have arrived @@ -1228,7 +1228,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen registrationRecoveryPasswordsManager.remove(account.getIdentifier(IdentityType.PNI))) .thenCompose(ignored -> accounts.delete(account.getUuid(), additionalWriteItems)) .thenCompose(ignored -> redisDeleteAsync(account)) - .thenRun(() -> disconnectionRequestManager.requestDisconnection(account.getUuid())); + .thenRun(() -> disconnectionRequestManager.requestDisconnection(account)); } private String getAccountMapKey(String key) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java index 963018524..ccb579530 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java @@ -6,9 +6,13 @@ package org.whispersystems.textsecuregcm.auth; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import java.util.Collection; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.UUID; import java.util.concurrent.CountDownLatch; import org.junit.jupiter.api.AfterEach; @@ -16,7 +20,9 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.redis.RedisServerExtension; +import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; @Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) @@ -84,16 +90,26 @@ class DisconnectionRequestManagerTest { @Test void requestDisconnectionAllDevices() throws InterruptedException { + final Device primaryDevice = mock(Device.class); + when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID); + + final Device linkedDevice = mock(Device.class); + when(linkedDevice.getId()).thenReturn((byte) (Device.PRIMARY_ID + 1)); + final UUID accountIdentifier = UUID.randomUUID(); + final Account account = mock(Account.class); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); + when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice)); + final DisconnectionRequestTestListener listener = new DisconnectionRequestTestListener(); disconnectionRequestManager.addListener(listener); - disconnectionRequestManager.requestDisconnection(accountIdentifier).toCompletableFuture().join(); + disconnectionRequestManager.requestDisconnection(account).toCompletableFuture().join(); listener.waitForRequest(); assertEquals(accountIdentifier, listener.getAccountIdentifier()); - assertEquals(Device.ALL_POSSIBLE_DEVICE_IDS, listener.getDeviceIds()); + assertEquals(List.of(Device.PRIMARY_ID, (byte) (Device.PRIMARY_ID + 1)), listener.getDeviceIds()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java index f54e0f98b..9048a7212 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java @@ -8,6 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -314,7 +315,7 @@ public class AccountCreationDeletionIntegrationTest { final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciKeyPair); final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair); - final Account originalAccount = accountsManager.create(number, + final Account existingAccount = accountsManager.create(number, new AccountAttributes(true, 1, 1, "name".getBytes(StandardCharsets.UTF_8), "registration-lock", false, Set.of()), Collections.emptyList(), new IdentityKey(aciKeyPair.getPublicKey()), @@ -334,7 +335,7 @@ public class AccountCreationDeletionIntegrationTest { pniPqLastResortPreKey), null); - existingAccountUuid = originalAccount.getUuid(); + existingAccountUuid = existingAccount.getUuid(); } final String password = RandomStringUtils.secure().nextAlphanumeric(16); @@ -417,7 +418,8 @@ public class AccountCreationDeletionIntegrationTest { assertEquals(existingAccountUuid, reregisteredAccount.getUuid()); - verify(disconnectionRequestManager).requestDisconnection(existingAccountUuid); + verify(disconnectionRequestManager).requestDisconnection(argThat(account -> + account.getIdentifier(IdentityType.ACI).equals(existingAccountUuid) && account != reregisteredAccount)); } @Test @@ -492,7 +494,7 @@ public class AccountCreationDeletionIntegrationTest { assertFalse(keysManager.getLastResort(account.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); assertFalse(clientPublicKeysManager.findPublicKey(account.getUuid(), Device.PRIMARY_ID).join().isPresent()); - verify(disconnectionRequestManager).requestDisconnection(aci); + verify(disconnectionRequestManager).requestDisconnection(account); } @SuppressWarnings("OptionalUsedAsFieldOrParameterType") diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java index 9284edce3..0531d28a7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -350,7 +351,8 @@ class AccountsManagerChangeNumberIntegrationTest { assertEquals(secondNumber, accountsManager.getByAccountIdentifier(originalUuid).map(Account::getNumber).orElseThrow()); - verify(disconnectionRequestManager).requestDisconnection(existingAccountUuid); + verify(disconnectionRequestManager).requestDisconnection(argThat(disconnectedAccount -> + disconnectedAccount.getIdentifier(IdentityType.ACI).equals(existingAccountUuid) && disconnectedAccount != account)); assertEquals(Optional.of(existingAccountUuid), accountsManager.findRecentlyDeletedAccountIdentifier(originalPni)); assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondPni)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 7ce6ea1b0..be7f8ef9f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -867,7 +867,8 @@ class AccountsManagerTest { verify(keysManager, times(2)).deleteSingleUsePreKeys(phoneNumberIdentifiersByE164.get(e164)); verify(messagesManager, times(2)).clear(existingUuid); verify(profilesManager, times(2)).deleteAll(existingUuid, false); - verify(disconnectionRequestManager).requestDisconnection(existingUuid); + verify(disconnectionRequestManager).requestDisconnection(argThat(account -> + account.getIdentifier(IdentityType.ACI).equals(existingUuid) && account != reregisteredAccount)); } @Test