Validate intra-account messages before applying number changes

This commit is contained in:
Jon Chambers
2025-07-17 11:34:50 -04:00
committed by GitHub
parent 50bc6b2c62
commit 4ccd39fd55
8 changed files with 404 additions and 191 deletions

View File

@@ -18,6 +18,8 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import io.micrometer.core.instrument.Tag;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -25,8 +27,10 @@ import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import javax.annotation.Nullable;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
@@ -38,31 +42,30 @@ import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.tests.util.MultiRecipientMessageHelper;
import org.whispersystems.textsecuregcm.tests.util.TestRecipient;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
class MessageSenderTest {
private MessagesManager messagesManager;
private PushNotificationManager pushNotificationManager;
private MessageSender messageSender;
private ExperimentEnrollmentManager experimentEnrollmentManager;
@BeforeEach
void setUp() {
messagesManager = mock(MessagesManager.class);
pushNotificationManager = mock(PushNotificationManager.class);
experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
messageSender = new MessageSender(messagesManager, pushNotificationManager, experimentEnrollmentManager);
messageSender = new MessageSender(messagesManager, pushNotificationManager);
}
@@ -258,13 +261,183 @@ class MessageSenderTest {
mismatchedDevicesException.getMismatchedDevicesByServiceIdentifier());
}
@ParameterizedTest
@MethodSource
void validateIndividualMessageBundle(final Account destination,
final ServiceIdentifier destinationIdentifier,
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId,
final Map<Byte, Integer> registrationIdsByDeviceId,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Byte> syncMessageSenderDeviceId,
@Nullable final Class<? extends Exception> expectedExceptionClass) {
final Executable validateIndividualMessageBundle = () -> MessageSender.validateIndividualMessageBundle(destination,
destinationIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId,
syncMessageSenderDeviceId,
"Signal/Test");
if (expectedExceptionClass != null) {
assertThrows(expectedExceptionClass, validateIndividualMessageBundle);
} else {
assertDoesNotThrow(validateIndividualMessageBundle);
}
}
private static List<Arguments> validateIndividualMessageBundle() {
final ServiceIdentifier destinationIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final byte primaryDeviceId = Device.PRIMARY_ID;
final byte linkedDeviceId = primaryDeviceId + 1;
final int primaryDeviceRegistrationId = 17;
final int linkedDeviceRegistrationId = primaryDeviceRegistrationId + 1;
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(primaryDeviceId);
when(primaryDevice.getRegistrationId(IdentityType.ACI)).thenReturn(primaryDeviceRegistrationId);
final Device linkedDevice = mock(Device.class);
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
when(linkedDevice.getRegistrationId(IdentityType.ACI)).thenReturn(linkedDeviceRegistrationId);
final Account destination = mock(Account.class);
when(destination.isIdentifiedBy(any())).thenReturn(false);
when(destination.isIdentifiedBy(destinationIdentifier)).thenReturn(true);
when(destination.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));
when(destination.getDevice(anyByte())).thenReturn(Optional.empty());
when(destination.getDevice(primaryDeviceId)).thenReturn(Optional.of(primaryDevice));
when(destination.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice));
return List.of(
Arguments.argumentSet("Valid",
destination,
destinationIdentifier,
Map.of(
primaryDeviceId, generateEnvelope(null, 16),
linkedDeviceId, generateEnvelope(null, 16)),
Map.of(
primaryDeviceId, primaryDeviceRegistrationId,
linkedDeviceId, linkedDeviceRegistrationId),
Optional.empty(),
null),
Arguments.argumentSet("Mismatched service ID",
destination,
new AciServiceIdentifier(UUID.randomUUID()),
Map.of(
primaryDeviceId, generateEnvelope(null, 16),
linkedDeviceId, generateEnvelope(null, 16)),
Map.of(
primaryDeviceId, primaryDeviceRegistrationId,
linkedDeviceId, linkedDeviceRegistrationId),
Optional.empty(),
IllegalArgumentException.class),
Arguments.argumentSet("Sync message without source on all messages",
destination,
destinationIdentifier,
Map.of(linkedDeviceId, generateEnvelope(null, 16)),
Map.of(linkedDeviceId, linkedDeviceRegistrationId),
Optional.of(primaryDevice),
IllegalArgumentException.class),
Arguments.argumentSet("Sync message to other account",
destination,
destinationIdentifier,
Map.of(linkedDeviceId, generateEnvelope(new AciServiceIdentifier(UUID.randomUUID()), 16)),
Map.of(linkedDeviceId, linkedDeviceRegistrationId),
Optional.of(primaryDevice),
IllegalArgumentException.class),
Arguments.argumentSet("Sync message to other account",
destination,
destinationIdentifier,
Map.of(linkedDeviceId, generateEnvelope(new AciServiceIdentifier(UUID.randomUUID()), 16)),
Map.of(linkedDeviceId, linkedDeviceRegistrationId),
Optional.of(primaryDevice),
IllegalArgumentException.class),
Arguments.argumentSet("Non-sync message addressed to sender",
destination,
destinationIdentifier,
Map.of(
primaryDeviceId, generateEnvelope(destinationIdentifier, 16),
linkedDeviceId, generateEnvelope(destinationIdentifier, 16)),
Map.of(
primaryDeviceId, primaryDeviceRegistrationId,
linkedDeviceId, linkedDeviceRegistrationId),
Optional.empty(),
IllegalArgumentException.class),
Arguments.argumentSet("Non-sync message addressed to sender",
destination,
destinationIdentifier,
Map.of(
primaryDeviceId, generateEnvelope(destinationIdentifier, 16),
linkedDeviceId, generateEnvelope(destinationIdentifier, 16)),
Map.of(
primaryDeviceId, primaryDeviceRegistrationId,
linkedDeviceId, linkedDeviceRegistrationId),
Optional.empty(),
IllegalArgumentException.class),
Arguments.argumentSet("Mismatched devices in message set",
destination,
destinationIdentifier,
Map.of(
primaryDeviceId, generateEnvelope(null, 16),
linkedDeviceId + 1, generateEnvelope(null, 16)),
Map.of(
primaryDeviceId, primaryDeviceRegistrationId,
linkedDeviceId + 1, linkedDeviceRegistrationId),
Optional.empty(),
MismatchedDevicesException.class),
Arguments.argumentSet("Mismatched registration IDs",
destination,
destinationIdentifier,
Map.of(
primaryDeviceId, generateEnvelope(null, 16),
linkedDeviceId, generateEnvelope(null, 16)),
Map.of(
primaryDeviceId, primaryDeviceRegistrationId,
linkedDeviceId, linkedDeviceRegistrationId + 1),
Optional.empty(),
MismatchedDevicesException.class),
Arguments.argumentSet("Oversized message",
destination,
destinationIdentifier,
Map.of(
primaryDeviceId, generateEnvelope(null, MessageSender.MAX_MESSAGE_SIZE + 1),
linkedDeviceId, generateEnvelope(null, MessageSender.MAX_MESSAGE_SIZE + 1)),
Map.of(
primaryDeviceId, primaryDeviceRegistrationId,
linkedDeviceId, linkedDeviceRegistrationId),
Optional.empty(),
MessageTooLargeException.class)
);
}
private static MessageProtos.Envelope generateEnvelope(@Nullable ServiceIdentifier sourceIdentifier, final int contentLength) {
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
.setContent(ByteString.copyFrom(TestRandomUtil.nextBytes(contentLength)));
if (sourceIdentifier != null) {
envelopeBuilder.setSourceServiceId(sourceIdentifier.toServiceIdentifierString());
}
return envelopeBuilder.build();
}
@Test
void validateContentLength() {
assertThrows(MessageTooLargeException.class, () ->
MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE + 1, false, false, false, null));
MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE + 1, false, false, false, Tag.of(UserAgentTagUtil.PLATFORM_TAG, "test")));
assertDoesNotThrow(() ->
MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE, false, false, false, null));
MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE, false, false, false, Tag.of(UserAgentTagUtil.PLATFORM_TAG, "test")));
}
@ParameterizedTest