accept group send endorsements for multi-recipient sends

This commit is contained in:
Jonathan Klabunde Tomer
2024-04-10 16:51:09 -07:00
committed by GitHub
parent cdd2082b07
commit 2b652fe2a9
10 changed files with 164 additions and 147 deletions

View File

@@ -43,6 +43,7 @@ import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
@@ -80,6 +81,7 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.zkgroup.ServerPublicParams;
@@ -88,9 +90,11 @@ import org.signal.libsignal.zkgroup.groups.ClientZkGroupCipher;
import org.signal.libsignal.zkgroup.groups.GroupMasterKey;
import org.signal.libsignal.zkgroup.groups.GroupSecretParams;
import org.signal.libsignal.zkgroup.groups.UuidCiphertext;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredential;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredentialPresentation;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredentialResponse;
import org.signal.libsignal.zkgroup.groupsend.GroupSendDerivedKeyPair;
import org.signal.libsignal.zkgroup.groupsend.GroupSendEndorsement;
import org.signal.libsignal.zkgroup.groupsend.GroupSendFullToken;
import org.signal.libsignal.zkgroup.groupsend.GroupSendEndorsementsResponse.ReceivedEndorsements;
import org.signal.libsignal.zkgroup.groupsend.GroupSendEndorsementsResponse;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
@@ -134,6 +138,7 @@ import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.websocket.Stories;
import reactor.core.publisher.Mono;
@@ -193,6 +198,8 @@ class MessageControllerTest {
private static final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
private static final ServerSecretParams serverSecretParams = ServerSecretParams.generate();
private static final TestClock clock = TestClock.now();
private static final ResourceExtension resources = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
.addProvider(AuthHelper.getAuthFilter())
@@ -204,7 +211,7 @@ class MessageControllerTest {
new MessageController(rateLimiters, cardinalityEstimator, messageSender, receiptSender, accountsManager,
messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor,
messageDeliveryScheduler, ReportSpamTokenProvider.noop(), mock(ClientReleaseManager.class), dynamicConfigurationManager,
serverSecretParams, SpamChecker.noop()))
serverSecretParams, SpamChecker.noop(), clock))
.build();
@BeforeEach
@@ -258,6 +265,8 @@ class MessageControllerTest {
when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter);
when(rateLimiter.validateAsync(any(UUID.class))).thenReturn(CompletableFuture.completedFuture(null));
clock.unpin();
}
private static Device generateTestDevice(final byte id, final int registrationId, final int pniRegistrationId,
@@ -980,25 +989,11 @@ class MessageControllerTest {
bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes)
}
private static void writeMultiPayloadExcludedRecipient(final ByteBuffer bb, final ServiceIdentifier id, final boolean useExplicitIdentifier) {
if (useExplicitIdentifier) {
bb.put(id.toFixedWidthByteArray());
} else {
bb.put(UUIDUtil.toBytes(id.uuid()));
}
bb.put((byte) 0);
}
private static InputStream initializeMultiPayload(final List<Recipient> recipients, final byte[] buffer, final boolean explicitIdentifiers) {
return initializeMultiPayload(recipients, List.of(), buffer, explicitIdentifiers, 39);
return initializeMultiPayload(recipients, buffer, explicitIdentifiers, 39);
}
private static InputStream initializeMultiPayload(final List<Recipient> recipients, final List<ServiceIdentifier> excludedRecipients, final byte[] buffer, final boolean explicitIdentifiers) {
return initializeMultiPayload(recipients, excludedRecipients, buffer, explicitIdentifiers, 39);
}
private static InputStream initializeMultiPayload(final List<Recipient> recipients, final List<ServiceIdentifier> excludedRecipients, final byte[] buffer, final boolean explicitIdentifiers, final int payloadSize) {
private static InputStream initializeMultiPayload(final List<Recipient> recipients, final byte[] buffer, final boolean explicitIdentifiers, final int payloadSize) {
// initialize a binary payload according to our wire format
ByteBuffer bb = ByteBuffer.wrap(buffer);
bb.order(ByteOrder.BIG_ENDIAN);
@@ -1007,10 +1002,9 @@ class MessageControllerTest {
bb.put(explicitIdentifiers ? (byte) 0x23 : (byte) 0x22); // version byte
// count varint
writeVarint(bb, recipients.size() + excludedRecipients.size());
writeVarint(bb, recipients.size());
recipients.forEach(recipient -> writeMultiPayloadRecipient(bb, recipient, explicitIdentifiers));
excludedRecipients.forEach(recipient -> writeMultiPayloadExcludedRecipient(bb, recipient, explicitIdentifiers));
// now write the actual message body (empty for now)
assert(payloadSize >= 32);
@@ -1269,24 +1263,20 @@ class MessageControllerTest {
.argumentsForNextParameter(false, true); // urgent
}
@ParameterizedTest
@MethodSource
void testMultiRecipientMessageWithGroupSendCredential(
List<ServiceIdentifier> includedRecipients,
List<ServiceIdentifier> excludedRecipients,
int expectedStatus,
int expectedMessagesSent) throws Exception {
final List<Recipient> recipients = new ArrayList<>();
includedRecipients.forEach(
serviceIdentifier -> multiRecipientTargetMap().get(serviceIdentifier).forEach(
(deviceId, registrationId) ->
recipients.add(new Recipient(serviceIdentifier, deviceId, registrationId, new byte[48]))));
@Test
void testMultiRecipientMessageWithGroupSendEndorsements() throws Exception {
final List<Recipient> recipients = List.of(
new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]),
new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
InputStream stream = initializeMultiPayload(recipients, excludedRecipients, buffer, true);
InputStream stream = initializeMultiPayload(recipients, buffer, true);
final AciServiceIdentifier senderId = new AciServiceIdentifier(UUID.randomUUID());
clock.pin(Instant.parse("2024-04-09T12:00:00.00Z"));
Response response = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
@@ -1296,77 +1286,106 @@ class MessageControllerTest {
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HeaderUtils.GROUP_SEND_CREDENTIAL, validGroupSendCredentialHeader(
senderId,
List.of(senderId, SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID)))
.header(HeaderUtils.GROUP_SEND_TOKEN, validGroupSendTokenHeader(
senderId, List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z")))
.put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE));
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus)));
assertThat("Unexpected response", response.getStatus(), is(equalTo(200)));
verify(messageSender,
exactly(expectedMessagesSent))
exactly(3))
.sendMessage(
any(),
any(),
argThat(env -> !env.hasSourceUuid() && !env.hasSourceDevice()),
eq(true));
if (expectedStatus == 200) {
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assertThat(smrmr.uuids404(), is(empty()));
}
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assertThat(smrmr.uuids404(), is(empty()));
}
private static Stream<Arguments> testMultiRecipientMessageWithGroupSendCredential() {
return Stream.of(
// All members present in included or excluded recipients: success, deliver to included recipients only
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), List.of(), 200, 3),
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID), List.of(MULTI_DEVICE_ACI_ID), 200, 1),
Arguments.of(List.of(MULTI_DEVICE_ACI_ID), List.of(SINGLE_DEVICE_ACI_ID), 200, 2),
@Test
void testMultiRecipientMessageWithInvalidGroupSendEndorsements() throws Exception {
final List<Recipient> recipients = List.of(
new Recipient(NONEXISTENT_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]),
new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]));
// No included recipients: request is bad
Arguments.of(List.of(), List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), 400, 0),
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
InputStream stream = initializeMultiPayload(recipients, buffer, true);
final AciServiceIdentifier senderId = new AciServiceIdentifier(UUID.randomUUID());
// Some recipients both included and excluded: request is bad
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), List.of(SINGLE_DEVICE_ACI_ID), 400, 0),
clock.pin(Instant.parse("2024-04-09T12:00:00.00Z"));
// Included recipient not covered by credential: forbid
Arguments.of(List.of(NONEXISTENT_ACI_ID), List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), 401, 0),
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, NONEXISTENT_ACI_ID), List.of(MULTI_DEVICE_ACI_ID), 401, 0),
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID, NONEXISTENT_ACI_ID), List.of(), 401, 0),
Response response = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", true)
.queryParam("ts", 1663798405641L)
.queryParam("story", false)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HeaderUtils.GROUP_SEND_TOKEN, validGroupSendTokenHeader(
senderId, List.of(MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z")))
.put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE));
// Excluded recipient not covered by credential: forbid
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), List.of(NONEXISTENT_ACI_ID), 401, 0),
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID), List.of(NONEXISTENT_ACI_ID, MULTI_DEVICE_ACI_ID), 401, 0),
Arguments.of(List.of(MULTI_DEVICE_ACI_ID), List.of(NONEXISTENT_ACI_ID, SINGLE_DEVICE_ACI_ID), 401, 0),
// Some recipients not in included or excluded list: forbid
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID), List.of(), 401, 0),
Arguments.of(List.of(MULTI_DEVICE_ACI_ID), List.of(), 401, 0),
// Substituting a PNI for an ACI is not allowed
Arguments.of(List.of(SINGLE_DEVICE_PNI_ID, MULTI_DEVICE_ACI_ID), List.of(), 401, 0));
assertThat("Unexpected response", response.getStatus(), is(equalTo(401)));
verifyNoMoreInteractions(messageSender);
}
private String validGroupSendCredentialHeader(AciServiceIdentifier sender, List<ServiceIdentifier> allGroupMembers) throws Exception {
@Test
void testMultiRecipientMessageWithExpiredGroupSendEndorsements() throws Exception {
final List<Recipient> recipients = List.of(
new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]),
new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
InputStream stream = initializeMultiPayload(recipients, buffer, true);
final AciServiceIdentifier senderId = new AciServiceIdentifier(UUID.randomUUID());
clock.pin(Instant.parse("2024-04-10T12:00:00.00Z"));
Response response = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", true)
.queryParam("ts", 1663798405641L)
.queryParam("story", false)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HeaderUtils.GROUP_SEND_TOKEN, validGroupSendTokenHeader(
senderId, List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z")))
.put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE));
assertThat("Unexpected response", response.getStatus(), is(equalTo(401)));
verifyNoMoreInteractions(messageSender);
}
private String validGroupSendTokenHeader(AciServiceIdentifier sender, List<ServiceIdentifier> recipients, Instant expiration) throws Exception {
final ServerPublicParams serverPublicParams = serverSecretParams.getPublicParams();
final GroupMasterKey groupMasterKey = new GroupMasterKey(new byte[32]);
final GroupSecretParams groupSecretParams = GroupSecretParams.deriveFromMasterKey(groupMasterKey);
final ClientZkGroupCipher clientZkGroupCipher = new ClientZkGroupCipher(groupSecretParams);
UuidCiphertext senderCiphertext = clientZkGroupCipher.encrypt(sender.toLibsignal());
List<UuidCiphertext> groupCiphertexts = allGroupMembers.stream()
.map(ServiceIdentifier::toLibsignal)
List<ServiceId> groupPlaintexts = Stream.concat(Stream.of(sender), recipients.stream()).map(ServiceIdentifier::toLibsignal).toList();
List<UuidCiphertext> groupCiphertexts = groupPlaintexts.stream()
.map(clientZkGroupCipher::encrypt)
.collect(Collectors.toList());
GroupSendCredentialResponse credentialResponse =
GroupSendCredentialResponse.issueCredential(groupCiphertexts, senderCiphertext, serverSecretParams);
GroupSendCredential credential =
credentialResponse.receive(
allGroupMembers.stream().map(ServiceIdentifier::toLibsignal).collect(Collectors.toList()),
.toList();
GroupSendDerivedKeyPair keyPair = GroupSendDerivedKeyPair.forExpiration(expiration, serverSecretParams);
GroupSendEndorsementsResponse endorsementsResponse =
GroupSendEndorsementsResponse.issue(groupCiphertexts, keyPair);
ReceivedEndorsements endorsements =
endorsementsResponse.receive(
groupPlaintexts,
sender.toLibsignal(),
serverPublicParams,
groupSecretParams);
GroupSendCredentialPresentation presentation = credential.present(serverPublicParams);
return Base64.getEncoder().encodeToString(presentation.serialize());
expiration.minus(Duration.ofDays(1)),
groupSecretParams,
serverPublicParams);
GroupSendFullToken token = endorsements.combinedEndorsement().toFullToken(groupSecretParams, expiration);
return Base64.getEncoder().encodeToString(token.serialize());
}
@ParameterizedTest
@@ -1408,7 +1427,7 @@ class MessageControllerTest {
.request()
.header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(initializeMultiPayload(recipients, List.of(), new byte[257<<10], true, 256<<10), MultiRecipientMessageProvider.MEDIA_TYPE));
.put(Entity.entity(initializeMultiPayload(recipients, new byte[257<<10], true, 256<<10), MultiRecipientMessageProvider.MEDIA_TYPE));
checkBadMultiRecipientResponse(response, 400);
}