mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-23 21:28:06 +01:00
Internalize destination device list/registration ID checks in MessageSender
This commit is contained in:
@@ -22,6 +22,9 @@ import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.OptionalInt;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
@@ -30,12 +33,22 @@ import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.junitpioneer.jupiter.cartesian.CartesianTest;
|
||||
import org.signal.libsignal.protocol.InvalidMessageException;
|
||||
import org.signal.libsignal.protocol.InvalidVersionException;
|
||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
||||
import org.whispersystems.textsecuregcm.tests.util.MultiRecipientMessageHelper;
|
||||
import org.whispersystems.textsecuregcm.tests.util.TestRecipient;
|
||||
|
||||
class MessageSenderTest {
|
||||
|
||||
@@ -60,7 +73,9 @@ class MessageSenderTest {
|
||||
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
|
||||
final byte deviceId = Device.PRIMARY_ID;
|
||||
final int registrationId = 17;
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
final Device device = mock(Device.class);
|
||||
@@ -71,7 +86,11 @@ class MessageSenderTest {
|
||||
|
||||
when(account.getUuid()).thenReturn(accountIdentifier);
|
||||
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
||||
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
|
||||
if (hasPushToken) {
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
@@ -82,7 +101,10 @@ class MessageSenderTest {
|
||||
|
||||
when(messagesManager.insert(any(), any())).thenReturn(Map.of(deviceId, clientPresent));
|
||||
|
||||
assertDoesNotThrow(() -> messageSender.sendMessages(account, Map.of(device.getId(), message)));
|
||||
assertDoesNotThrow(() -> messageSender.sendMessages(account,
|
||||
serviceIdentifier,
|
||||
Map.of(device.getId(), message),
|
||||
Map.of(device.getId(), registrationId)));
|
||||
|
||||
final MessageProtos.Envelope expectedMessage = ephemeral
|
||||
? message.toBuilder().setEphemeral(true).build()
|
||||
@@ -97,23 +119,61 @@ class MessageSenderTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void sendMessageMismatchedDevices() {
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
|
||||
final byte deviceId = Device.PRIMARY_ID;
|
||||
final int registrationId = 17;
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
final Device device = mock(Device.class);
|
||||
final MessageProtos.Envelope message = MessageProtos.Envelope.newBuilder().build();
|
||||
|
||||
when(account.getUuid()).thenReturn(accountIdentifier);
|
||||
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
||||
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
|
||||
final MismatchedDevicesException mismatchedDevicesException =
|
||||
assertThrows(MismatchedDevicesException.class, () -> messageSender.sendMessages(account,
|
||||
serviceIdentifier,
|
||||
Map.of(device.getId(), message),
|
||||
Map.of(device.getId(), registrationId + 1)));
|
||||
|
||||
assertEquals(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId)),
|
||||
mismatchedDevicesException.getMismatchedDevices());
|
||||
}
|
||||
|
||||
@CartesianTest
|
||||
void sendMultiRecipientMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean ephemeral,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException {
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken)
|
||||
throws NotPushRegisteredException, InvalidMessageException, InvalidVersionException {
|
||||
|
||||
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
|
||||
final byte deviceId = Device.PRIMARY_ID;
|
||||
final int registrationId = 17;
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
final Device device = mock(Device.class);
|
||||
|
||||
when(account.getUuid()).thenReturn(accountIdentifier);
|
||||
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
||||
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
|
||||
if (hasPushToken) {
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
@@ -125,12 +185,19 @@ class MessageSenderTest {
|
||||
when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, clientPresent))));
|
||||
|
||||
assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(mock(SealedSenderMultiRecipientMessage.class),
|
||||
Collections.emptyMap(),
|
||||
System.currentTimeMillis(),
|
||||
false,
|
||||
ephemeral,
|
||||
urgent)
|
||||
final SealedSenderMultiRecipientMessage multiRecipientMessage =
|
||||
SealedSenderMultiRecipientMessage.parse(MultiRecipientMessageHelper.generateMultiRecipientMessage(
|
||||
List.of(new TestRecipient(serviceIdentifier, deviceId, registrationId, new byte[48]))));
|
||||
|
||||
final SealedSenderMultiRecipientMessage.Recipient recipient =
|
||||
multiRecipientMessage.getRecipients().values().iterator().next();
|
||||
|
||||
assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(multiRecipientMessage,
|
||||
Map.of(recipient, account),
|
||||
System.currentTimeMillis(),
|
||||
false,
|
||||
ephemeral,
|
||||
urgent)
|
||||
.join());
|
||||
|
||||
if (expectPushNotificationAttempt) {
|
||||
@@ -140,6 +207,49 @@ class MessageSenderTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void sendMultiRecipientMessageMismatchedDevices() throws InvalidMessageException, InvalidVersionException {
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
|
||||
final byte deviceId = Device.PRIMARY_ID;
|
||||
final int registrationId = 17;
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
final Device device = mock(Device.class);
|
||||
|
||||
when(account.getUuid()).thenReturn(accountIdentifier);
|
||||
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
||||
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
|
||||
final SealedSenderMultiRecipientMessage multiRecipientMessage =
|
||||
SealedSenderMultiRecipientMessage.parse(MultiRecipientMessageHelper.generateMultiRecipientMessage(
|
||||
List.of(new TestRecipient(serviceIdentifier, deviceId, registrationId + 1, new byte[48]))));
|
||||
|
||||
final SealedSenderMultiRecipientMessage.Recipient recipient =
|
||||
multiRecipientMessage.getRecipients().values().iterator().next();
|
||||
|
||||
when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, true))));
|
||||
|
||||
final MultiRecipientMismatchedDevicesException mismatchedDevicesException =
|
||||
assertThrows(MultiRecipientMismatchedDevicesException.class,
|
||||
() -> messageSender.sendMultiRecipientMessage(multiRecipientMessage,
|
||||
Map.of(recipient, account),
|
||||
System.currentTimeMillis(),
|
||||
false,
|
||||
false,
|
||||
true)
|
||||
.join());
|
||||
|
||||
assertEquals(Map.of(serviceIdentifier, new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId))),
|
||||
mismatchedDevicesException.getMismatchedDevicesByServiceIdentifier());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void getDeliveryChannelName(final Device device, final String expectedChannelName) {
|
||||
@@ -183,4 +293,87 @@ class MessageSenderTest {
|
||||
assertDoesNotThrow(() ->
|
||||
MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE, false, false, false, null));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void getMismatchedDevices(final Account account,
|
||||
final ServiceIdentifier serviceIdentifier,
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId,
|
||||
final byte excludedDeviceId,
|
||||
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<MismatchedDevices> expectedMismatchedDevices) {
|
||||
|
||||
assertEquals(expectedMismatchedDevices,
|
||||
MessageSender.getMismatchedDevices(account, serviceIdentifier, registrationIdsByDeviceId, excludedDeviceId));
|
||||
}
|
||||
|
||||
private static List<Arguments> getMismatchedDevices() {
|
||||
final byte primaryDeviceId = Device.PRIMARY_ID;
|
||||
final byte linkedDeviceId = primaryDeviceId + 1;
|
||||
final byte extraDeviceId = linkedDeviceId + 1;
|
||||
|
||||
final int primaryDeviceAciRegistrationId = 2;
|
||||
final int primaryDevicePniRegistrationId = 3;
|
||||
final int linkedDeviceAciRegistrationId = 5;
|
||||
final int linkedDevicePniRegistrationId = 7;
|
||||
|
||||
final Device primaryDevice = mock(Device.class);
|
||||
when(primaryDevice.getId()).thenReturn(primaryDeviceId);
|
||||
when(primaryDevice.getRegistrationId()).thenReturn(primaryDeviceAciRegistrationId);
|
||||
when(primaryDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(primaryDevicePniRegistrationId));
|
||||
|
||||
final Device linkedDevice = mock(Device.class);
|
||||
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
|
||||
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceAciRegistrationId);
|
||||
when(linkedDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(linkedDevicePniRegistrationId));
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));
|
||||
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
|
||||
when(account.getDevice(primaryDeviceId)).thenReturn(Optional.of(primaryDevice));
|
||||
when(account.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice));
|
||||
|
||||
final AciServiceIdentifier aciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
|
||||
final PniServiceIdentifier pniServiceIdentifier = new PniServiceIdentifier(UUID.randomUUID());
|
||||
|
||||
return List.of(
|
||||
Arguments.argumentSet("Complete device list for ACI, no devices excluded",
|
||||
account,
|
||||
aciServiceIdentifier,
|
||||
Map.of(
|
||||
primaryDeviceId, primaryDeviceAciRegistrationId,
|
||||
linkedDeviceId, linkedDeviceAciRegistrationId
|
||||
),
|
||||
MessageSender.NO_EXCLUDED_DEVICE_ID,
|
||||
Optional.empty()),
|
||||
|
||||
Arguments.argumentSet("Complete device list for PNI, no devices excluded",
|
||||
account,
|
||||
pniServiceIdentifier,
|
||||
Map.of(
|
||||
primaryDeviceId, primaryDevicePniRegistrationId,
|
||||
linkedDeviceId, linkedDevicePniRegistrationId
|
||||
),
|
||||
MessageSender.NO_EXCLUDED_DEVICE_ID,
|
||||
Optional.empty()),
|
||||
|
||||
Arguments.argumentSet("Complete device list, device excluded",
|
||||
account,
|
||||
aciServiceIdentifier,
|
||||
Map.of(
|
||||
linkedDeviceId, linkedDeviceAciRegistrationId
|
||||
),
|
||||
primaryDeviceId,
|
||||
Optional.empty()),
|
||||
|
||||
Arguments.argumentSet("Mismatched devices",
|
||||
account,
|
||||
aciServiceIdentifier,
|
||||
Map.of(
|
||||
linkedDeviceId, linkedDeviceAciRegistrationId + 1,
|
||||
extraDeviceId, 17
|
||||
),
|
||||
MessageSender.NO_EXCLUDED_DEVICE_ID,
|
||||
Optional.of(new MismatchedDevices(Set.of(primaryDeviceId), Set.of(extraDeviceId), Set.of(linkedDeviceId))))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user