Fix registration ID map construction when changing numbers

This commit is contained in:
Jon Chambers
2025-04-15 14:57:28 -04:00
committed by GitHub
parent 2f2ae7cec5
commit 3c40e72d27
3 changed files with 54 additions and 28 deletions

View File

@@ -9,13 +9,16 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
@@ -35,8 +38,10 @@ import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.TestClock;
public class ChangeNumberManagerTest {
private AccountsManager accountsManager;
@@ -45,11 +50,13 @@ public class ChangeNumberManagerTest {
private Map<Account, UUID> updatedPhoneNumberIdentifiersByAccount;
private static final TestClock CLOCK = TestClock.pinned(Instant.now());
@BeforeEach
void setUp() throws Exception {
accountsManager = mock(AccountsManager.class);
messageSender = mock(MessageSender.class);
changeNumberManager = new ChangeNumberManager(messageSender, accountsManager);
changeNumberManager = new ChangeNumberManager(messageSender, accountsManager, CLOCK);
updatedPhoneNumberIdentifiersByAccount = new HashMap<>();
@@ -132,45 +139,59 @@ public class ChangeNumberManagerTest {
when(account.getUuid()).thenReturn(aci);
when(account.getPhoneNumberIdentifier()).thenReturn(pni);
final Device d2 = mock(Device.class);
final byte deviceId2 = 2;
when(d2.getId()).thenReturn(deviceId2);
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(primaryDevice.getRegistrationId()).thenReturn(7);
when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final Device linkedDevice = mock(Device.class);
final byte linkedDeviceId = Device.PRIMARY_ID + 1;
final int linkedDeviceRegistrationId = 17;
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceRegistrationId);
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primaryDevice));
when(account.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice));
when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19);
linkedDeviceId, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, linkedDeviceId, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.type()).thenReturn(1);
when(msg.destinationDeviceId()).thenReturn(linkedDeviceId);
when(msg.destinationRegistrationId()).thenReturn(linkedDeviceRegistrationId);
when(msg.content()).thenReturn(new byte[]{1});
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds, null);
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds);
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
final MessageProtos.Envelope expectedEnvelope = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.forNumber(msg.type()))
.setClientTimestamp(CLOCK.millis())
.setServerTimestamp(CLOCK.millis())
.setDestinationServiceId(new AciServiceIdentifier(aci).toServiceIdentifierString())
.setContent(ByteString.copyFrom(msg.content()))
.setSourceServiceId(new AciServiceIdentifier(aci).toServiceIdentifierString())
.setSourceDevice(Device.PRIMARY_ID)
.setUpdatedPni(updatedPhoneNumberIdentifiersByAccount.get(account).toString())
.setUrgent(true)
.setEphemeral(false)
.build();
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni()));
verify(messageSender).sendMessages(argThat(a -> a.getUuid().equals(aci)),
eq(new AciServiceIdentifier(aci)),
eq(Map.of(linkedDeviceId, expectedEnvelope)),
eq(Map.of(linkedDeviceId, linkedDeviceRegistrationId)),
eq(Optional.of(Device.PRIMARY_ID)),
any());
}
@Test
void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception {
final String originalE164 = "+18005551234";