Group Send Credential support in chat

This commit is contained in:
Jonathan Klabunde Tomer
2024-01-04 11:38:57 -08:00
committed by GitHub
parent 195f23c347
commit e1ad25cee0
14 changed files with 337 additions and 108 deletions

View File

@@ -40,12 +40,12 @@ import org.signal.libsignal.zkgroup.auth.ServerZkAuthOperations;
import org.signal.libsignal.zkgroup.calllinks.CallLinkAuthCredentialResponse;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.DeliveryCertificate;
import org.whispersystems.textsecuregcm.entities.GroupCredentials;
import org.whispersystems.textsecuregcm.entities.MessageProtos.SenderCertificate;
import org.whispersystems.textsecuregcm.entities.MessageProtos.ServerCertificate;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ExtendWith(DropwizardExtensionsSupport.class)
@@ -198,7 +198,7 @@ class CertificateControllerTest {
Response response = resources.getJerseyTest()
.target("/v1/certificate/delivery")
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1234".getBytes()))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("1234".getBytes()))
.get();
assertEquals(response.getStatus(), 401);

View File

@@ -48,7 +48,6 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
@@ -73,6 +72,7 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
@ExtendWith(DropwizardExtensionsSupport.class)
class KeysControllerTest {
@@ -494,7 +494,7 @@ class KeysControllerTest {
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.queryParam("pq", "true")
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
@@ -518,7 +518,7 @@ class KeysControllerTest {
Response result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID))
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.get();
assertThat(result).isNotNull();
@@ -530,7 +530,7 @@ class KeysControllerTest {
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("9999".getBytes()))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("9999".getBytes()))
.get();
assertThat(response.getStatus()).isEqualTo(401);
@@ -542,7 +542,7 @@ class KeysControllerTest {
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.request()
.header(OptionalAccess.UNIDENTIFIED, "$$$$$$$$$")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, "$$$$$$$$$")
.get();
assertThat(response.getStatus()).isEqualTo(401);

View File

@@ -81,10 +81,20 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.util.Hex;
import org.signal.libsignal.zkgroup.groups.ClientZkGroupCipher;
import org.signal.libsignal.zkgroup.groups.GroupMasterKey;
import org.signal.libsignal.zkgroup.groups.GroupSecretParams;
import org.signal.libsignal.zkgroup.groups.UuidCiphertext;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredential;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredentialPresentation;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredentialResponse;
import org.signal.libsignal.zkgroup.ServerPublicParams;
import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration;
@@ -125,6 +135,7 @@ import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
@@ -137,15 +148,19 @@ import reactor.core.scheduler.Schedulers;
class MessageControllerTest {
private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111";
private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID();
private static final UUID SINGLE_DEVICE_PNI = UUID.randomUUID();
private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID();
private static final ServiceIdentifier SINGLE_DEVICE_ACI_ID = new AciServiceIdentifier(SINGLE_DEVICE_UUID);
private static final UUID SINGLE_DEVICE_PNI = UUID.randomUUID();
private static final ServiceIdentifier SINGLE_DEVICE_PNI_ID = new PniServiceIdentifier(SINGLE_DEVICE_PNI);
private static final byte SINGLE_DEVICE_ID1 = 1;
private static final int SINGLE_DEVICE_REG_ID1 = 111;
private static final int SINGLE_DEVICE_PNI_REG_ID1 = 1111;
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID();
private static final ServiceIdentifier MULTI_DEVICE_ACI_ID = new AciServiceIdentifier(MULTI_DEVICE_UUID);
private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID();
private static final ServiceIdentifier MULTI_DEVICE_PNI_ID = new PniServiceIdentifier(MULTI_DEVICE_PNI);
private static final byte MULTI_DEVICE_ID1 = 1;
private static final byte MULTI_DEVICE_ID2 = 2;
private static final byte MULTI_DEVICE_ID3 = 3;
@@ -157,6 +172,8 @@ class MessageControllerTest {
private static final int MULTI_DEVICE_PNI_REG_ID3 = 4444;
private static final UUID NONEXISTENT_UUID = UUID.randomUUID();
private static final ServiceIdentifier NONEXISTENT_ACI_ID = new AciServiceIdentifier(NONEXISTENT_UUID);
private static final ServiceIdentifier NONEXISTENT_PNI_ID = new PniServiceIdentifier(NONEXISTENT_UUID);
private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes();
@@ -178,6 +195,7 @@ class MessageControllerTest {
private static final ExecutorService multiRecipientMessageExecutor = MoreExecutors.newDirectExecutorService();
private static final Scheduler messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
private static final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
private static final ServerSecretParams serverSecretParams = ServerSecretParams.generate();
private static final ResourceExtension resources = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
@@ -189,7 +207,8 @@ class MessageControllerTest {
.addResource(
new MessageController(rateLimiters, cardinalityEstimator, messageSender, receiptSender, accountsManager,
messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor,
messageDeliveryScheduler, ReportSpamTokenProvider.noop(), mock(ClientReleaseManager.class), dynamicConfigurationManager))
messageDeliveryScheduler, ReportSpamTokenProvider.noop(), mock(ClientReleaseManager.class), dynamicConfigurationManager,
serverSecretParams))
.build();
@BeforeEach
@@ -213,19 +232,19 @@ class MessageControllerTest {
Account internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID,
UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(SINGLE_DEVICE_PNI))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(SINGLE_DEVICE_ACI_ID)).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByServiceIdentifier(SINGLE_DEVICE_PNI_ID)).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByServiceIdentifier(MULTI_DEVICE_ACI_ID)).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(MULTI_DEVICE_PNI_ID)).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(Optional.of(internationalAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifier(NONEXISTENT_ACI_ID)).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifier(NONEXISTENT_PNI_ID)).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(SINGLE_DEVICE_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(SINGLE_DEVICE_PNI))).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(SINGLE_DEVICE_ACI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(SINGLE_DEVICE_PNI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(MULTI_DEVICE_ACI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(MULTI_DEVICE_PNI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(internationalAccount)));
final DynamicInboundMessageByteLimitConfiguration inboundMessageByteLimitConfiguration =
@@ -381,7 +400,7 @@ class MessageControllerTest {
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request()
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(SystemMapper.jsonMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
@@ -904,7 +923,7 @@ class MessageControllerTest {
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request()
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(new IncomingMessageList(
List.of(new IncomingMessage(1, (byte) 1, 1, new String(contentBytes))), false, true,
System.currentTimeMillis()),
@@ -965,7 +984,21 @@ class MessageControllerTest {
bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes)
}
private static void writeMultiPayloadExcludedRecipient(final ByteBuffer bb, final ServiceIdentifier id, final boolean useExplicitIdentifier) {
if (useExplicitIdentifier) {
bb.put(id.toFixedWidthByteArray());
} else {
bb.put(UUIDUtil.toBytes(id.uuid()));
}
bb.put((byte) 0);
}
private static InputStream initializeMultiPayload(List<Recipient> recipients, byte[] buffer, final boolean explicitIdentifiers) {
return initializeMultiPayload(recipients, List.of(), buffer, explicitIdentifiers);
}
private static InputStream initializeMultiPayload(List<Recipient> recipients, List<ServiceIdentifier> excludedRecipients, byte[] buffer, final boolean explicitIdentifiers) {
// initialize a binary payload according to our wire format
ByteBuffer bb = ByteBuffer.wrap(buffer);
bb.order(ByteOrder.BIG_ENDIAN);
@@ -974,17 +1007,15 @@ class MessageControllerTest {
bb.put(explicitIdentifiers ? (byte) 0x23 : (byte) 0x22); // version byte
// count varint
int nRecip = recipients.size();
int nRecip = recipients.size() + excludedRecipients.size();
while (nRecip > 127) {
bb.put((byte) (nRecip & 0x7F | 0x80));
nRecip = nRecip >> 7;
}
bb.put((byte)(nRecip & 0x7F));
Iterator<Recipient> it = recipients.iterator();
while (it.hasNext()) {
writeMultiPayloadRecipient(bb, it.next(), explicitIdentifiers);
}
recipients.forEach(recipient -> writeMultiPayloadRecipient(bb, recipient, explicitIdentifiers));
excludedRecipients.forEach(recipient -> writeMultiPayloadExcludedRecipient(bb, recipient, explicitIdentifiers));
// now write the actual message body (empty for now)
bb.put(new byte[39]); // payload (variable but >= 32, 39 bytes here)
@@ -1062,8 +1093,17 @@ class MessageControllerTest {
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
// start building the request
Invocation.Builder bldr = resources
// build correct or incorrect access header
final String accessHeader;
if (authorize) {
final long count = destinations.keySet().stream().map(accountsManager::getByServiceIdentifier).filter(Optional::isPresent).count();
accessHeader = Base64.getEncoder().encodeToString(count % 2 == 1 ? UNIDENTIFIED_ACCESS_BYTES : new byte[16]);
} else {
accessHeader = "BBBBBBBBBBBBBBBBBBBBBB==";
}
// make the PUT request
Response response = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", true)
@@ -1071,17 +1111,9 @@ class MessageControllerTest {
.queryParam("story", isStory)
.queryParam("urgent", urgent)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME");
// add access header if needed
if (authorize) {
final long count = destinations.keySet().stream().map(accountsManager::getByServiceIdentifier).filter(Optional::isPresent).count();
String encodedBytes = Base64.getEncoder().encodeToString(count % 2 == 1 ? UNIDENTIFIED_ACCESS_BYTES : new byte[16]);
bldr = bldr.header(OptionalAccess.UNIDENTIFIED, encodedBytes);
}
// make the PUT request
Response response = bldr.put(entity);
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessHeader)
.put(entity);
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus)));
verify(messageSender,
@@ -1105,18 +1137,18 @@ class MessageControllerTest {
private static Map<ServiceIdentifier, Map<Byte, Integer>> multiRecipientTargetMap() {
return
Map.of(
new AciServiceIdentifier(SINGLE_DEVICE_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
new PniServiceIdentifier(SINGLE_DEVICE_PNI), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1),
new AciServiceIdentifier(MULTI_DEVICE_UUID),
SINGLE_DEVICE_ACI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
SINGLE_DEVICE_PNI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1),
MULTI_DEVICE_ACI_ID,
Map.of(
MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1,
MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2),
new PniServiceIdentifier(MULTI_DEVICE_PNI),
MULTI_DEVICE_PNI_ID,
Map.of(
MULTI_DEVICE_ID1, MULTI_DEVICE_PNI_REG_ID1,
MULTI_DEVICE_ID2, MULTI_DEVICE_PNI_REG_ID2),
new AciServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
new PniServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1)
NONEXISTENT_ACI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
NONEXISTENT_PNI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1)
);
}
@@ -1137,16 +1169,16 @@ class MessageControllerTest {
@SuppressWarnings("unused")
private static ArgumentSets testMultiRecipientMessageNoPni() {
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAci = submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID));
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDeviceAci = submap(targets, new AciServiceIdentifier(MULTI_DEVICE_UUID));
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAci = submap(targets, SINGLE_DEVICE_ACI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDeviceAci = submap(targets, MULTI_DEVICE_ACI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsAci =
submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new AciServiceIdentifier(MULTI_DEVICE_UUID));
submap(targets, SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeAci =
submap(
targets,
new AciServiceIdentifier(SINGLE_DEVICE_UUID),
new AciServiceIdentifier(MULTI_DEVICE_UUID),
new AciServiceIdentifier(NONEXISTENT_UUID));
SINGLE_DEVICE_ACI_ID,
MULTI_DEVICE_ACI_ID,
NONEXISTENT_ACI_ID);
final boolean auth = true;
final boolean unauth = false;
@@ -1186,18 +1218,18 @@ class MessageControllerTest {
private static ArgumentSets testMultiRecipientMessagePni() {
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDevicePni = submap(targets, new PniServiceIdentifier(SINGLE_DEVICE_PNI));
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDevicePni = submap(targets, SINGLE_DEVICE_PNI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAciAndPni = submap(
targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(SINGLE_DEVICE_PNI));
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDevicePni = submap(targets, new PniServiceIdentifier(MULTI_DEVICE_PNI));
targets, SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_PNI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDevicePni = submap(targets, MULTI_DEVICE_PNI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsMixed =
submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(MULTI_DEVICE_PNI));
submap(targets, SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_PNI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeMixed =
submap(
targets,
new PniServiceIdentifier(SINGLE_DEVICE_PNI),
new AciServiceIdentifier(MULTI_DEVICE_UUID),
new PniServiceIdentifier(NONEXISTENT_UUID));
SINGLE_DEVICE_PNI_ID,
MULTI_DEVICE_ACI_ID,
NONEXISTENT_PNI_ID);
final boolean auth = true;
final boolean unauth = false;
@@ -1232,13 +1264,113 @@ class MessageControllerTest {
.argumentsForNextParameter(false, true); // urgent
}
@ParameterizedTest
@MethodSource
void testMultiRecipientMessageWithGroupSendCredential(
List<ServiceIdentifier> includedRecipients,
List<ServiceIdentifier> excludedRecipients,
int expectedStatus,
int expectedMessagesSent) throws Exception {
final List<Recipient> recipients = new ArrayList<>();
includedRecipients.forEach(
serviceIdentifier -> multiRecipientTargetMap().get(serviceIdentifier).forEach(
(deviceId, registrationId) ->
recipients.add(new Recipient(serviceIdentifier, deviceId, registrationId, new byte[48]))));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
InputStream stream = initializeMultiPayload(recipients, excludedRecipients, buffer, true);
final AciServiceIdentifier senderId = new AciServiceIdentifier(UUID.randomUUID());
Response response = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", true)
.queryParam("ts", 1663798405641L)
.queryParam("story", false)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HeaderUtils.GROUP_SEND_CREDENTIAL, validGroupSendCredentialHeader(
senderId,
List.of(senderId, SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID)))
.put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE));
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus)));
verify(messageSender,
exactly(expectedMessagesSent))
.sendMessage(
any(),
any(),
argThat(env -> !env.hasSourceUuid() && !env.hasSourceDevice()),
eq(true));
if (expectedStatus == 200) {
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assertThat(smrmr.uuids404(), is(empty()));
}
}
private static Stream<Arguments> testMultiRecipientMessageWithGroupSendCredential() {
return Stream.of(
// All members present in included or excluded recipients: success, deliver to included recipients only
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), List.of(), 200, 3),
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID), List.of(MULTI_DEVICE_ACI_ID), 200, 1),
Arguments.of(List.of(MULTI_DEVICE_ACI_ID), List.of(SINGLE_DEVICE_ACI_ID), 200, 2),
// No included recipients: request is bad
Arguments.of(List.of(), List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), 400, 0),
// Some recipients both included and excluded: request is bad
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), List.of(SINGLE_DEVICE_ACI_ID), 400, 0),
// Included recipient not covered by credential: forbid
Arguments.of(List.of(NONEXISTENT_ACI_ID), List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), 401, 0),
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, NONEXISTENT_ACI_ID), List.of(MULTI_DEVICE_ACI_ID), 401, 0),
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID, NONEXISTENT_ACI_ID), List.of(), 401, 0),
// Excluded recipient not covered by credential: forbid
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), List.of(NONEXISTENT_ACI_ID), 401, 0),
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID), List.of(NONEXISTENT_ACI_ID, MULTI_DEVICE_ACI_ID), 401, 0),
Arguments.of(List.of(MULTI_DEVICE_ACI_ID), List.of(NONEXISTENT_ACI_ID, SINGLE_DEVICE_ACI_ID), 401, 0),
// Some recipients not in included or excluded list: forbid
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID), List.of(), 401, 0),
Arguments.of(List.of(MULTI_DEVICE_ACI_ID), List.of(), 401, 0),
// Substituting a PNI for an ACI is not allowed
Arguments.of(List.of(SINGLE_DEVICE_PNI_ID, MULTI_DEVICE_ACI_ID), List.of(), 401, 0));
}
private String validGroupSendCredentialHeader(AciServiceIdentifier sender, List<ServiceIdentifier> allGroupMembers) throws Exception {
final ServerPublicParams serverPublicParams = serverSecretParams.getPublicParams();
final GroupMasterKey groupMasterKey = new GroupMasterKey(new byte[32]);
final GroupSecretParams groupSecretParams = GroupSecretParams.deriveFromMasterKey(groupMasterKey);
final ClientZkGroupCipher clientZkGroupCipher = new ClientZkGroupCipher(groupSecretParams);
UuidCiphertext senderCiphertext = clientZkGroupCipher.encrypt(sender.toLibsignal());
List<UuidCiphertext> groupCiphertexts = allGroupMembers.stream()
.map(ServiceIdentifier::toLibsignal)
.map(clientZkGroupCipher::encrypt)
.collect(Collectors.toList());
GroupSendCredentialResponse credentialResponse =
GroupSendCredentialResponse.issueCredential(groupCiphertexts, senderCiphertext, serverSecretParams);
GroupSendCredential credential =
credentialResponse.receive(
allGroupMembers.stream().map(ServiceIdentifier::toLibsignal).collect(Collectors.toList()),
sender.toLibsignal(),
serverPublicParams,
groupSecretParams);
GroupSendCredentialPresentation presentation = credential.present(serverPublicParams);
return Base64.getEncoder().encodeToString(presentation.serialize());
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception {
final List<Recipient> recipients = List.of(
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]));
new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]),
new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]));
Response response = resources
.getJerseyTest()
@@ -1249,7 +1381,7 @@ class MessageControllerTest {
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], useExplicitIdentifier), MultiRecipientMessageProvider.MEDIA_TYPE));
checkBadMultiRecipientResponse(response, 400);
@@ -1266,7 +1398,7 @@ class MessageControllerTest {
.target(String.format("/v1/messages/%s", unknownUUID))
.queryParam("story", "true")
.request()
.header(OptionalAccess.UNIDENTIFIED, accessBytes)
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessBytes)
.put(Entity.entity(list, MediaType.APPLICATION_JSON_TYPE));
assertThat("200 masks unknown recipient", response.getStatus(), is(equalTo(200)));
@@ -1278,13 +1410,13 @@ class MessageControllerTest {
final Recipient r1;
if (known) {
r1 = new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]);
r1 = new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]);
} else {
r1 = new Recipient(new AciServiceIdentifier(UUID.randomUUID()), (byte) 99, 999, new byte[48]);
}
Recipient r2 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]);
Recipient r3 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]);
Recipient r2 = new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]);
Recipient r3 = new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]);
List<Recipient> recipients = List.of(r1, r2, r3);
@@ -1307,7 +1439,7 @@ class MessageControllerTest {
.queryParam("story", story)
.request()
.header(HttpHeaders.USER_AGENT, "Test User Agent")
.header(OptionalAccess.UNIDENTIFIED, accessBytes);
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessBytes);
// make the PUT request
Response response = bldr.put(entity);
@@ -1363,7 +1495,7 @@ class MessageControllerTest {
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
// make the PUT request
final Response response = invocationBuilder.put(entity);
@@ -1381,8 +1513,8 @@ class MessageControllerTest {
private static Stream<Arguments> sendMultiRecipientMessageMismatchedDevices() {
return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI)));
Arguments.of(MULTI_DEVICE_ACI_ID),
Arguments.of(MULTI_DEVICE_PNI_ID));
}
@ParameterizedTest
@@ -1410,7 +1542,7 @@ class MessageControllerTest {
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
// make the PUT request
final Response response = invocationBuilder.put(entity);
@@ -1429,8 +1561,8 @@ class MessageControllerTest {
private static Stream<Arguments> sendMultiRecipientMessageStaleDevices() {
return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI)));
Arguments.of(MULTI_DEVICE_ACI_ID),
Arguments.of(MULTI_DEVICE_PNI_ID));
}
@ParameterizedTest
@@ -1460,7 +1592,7 @@ class MessageControllerTest {
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
doThrow(NotPushRegisteredException.class)
.when(messageSender).sendMessage(any(), any(), any(), anyBoolean());
@@ -1473,13 +1605,13 @@ class MessageControllerTest {
private static Stream<Arguments> sendMultiRecipientMessage404() {
return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_REG_ID1, MULTI_DEVICE_REG_ID2),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI), MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2));
Arguments.of(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_REG_ID2),
Arguments.of(MULTI_DEVICE_PNI_ID, MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2));
}
@Test
void sendMultiRecipientMessageStoryRateLimited() {
final List<Recipient> recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
final List<Recipient> recipients = List.of(new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
// InputStream stream = initializeMultiPayload(recipientUUID, buffer);
@@ -1498,7 +1630,7 @@ class MessageControllerTest {
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
when(rateLimiter.validateAsync(any(UUID.class)))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofSeconds(77), true)));

View File

@@ -75,7 +75,6 @@ import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequest;
import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequestContext;
import org.signal.libsignal.zkgroup.profiles.ServerZkProfileOperations;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.configuration.BadgeConfiguration;
import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
@@ -107,6 +106,7 @@ import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@@ -279,7 +279,7 @@ class ProfileControllerTest {
final BaseProfileResponse profile = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_UUID_TWO)
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY))
.get(BaseProfileResponse.class);
assertThat(profile.getIdentityKey()).isEqualTo(ACCOUNT_TWO_IDENTITY_KEY);
@@ -295,7 +295,7 @@ class ProfileControllerTest {
final Response response = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_UUID_TWO)
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("incorrect".getBytes()))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("incorrect".getBytes()))
.get();
assertThat(response.getStatus()).isEqualTo(401);
@@ -306,7 +306,7 @@ class ProfileControllerTest {
final Response response = resources.getJerseyTest()
.target("/v1/profile/" + UUID.randomUUID())
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY))
.get();
assertThat(response.getStatus()).isEqualTo(401);
@@ -351,7 +351,7 @@ class ProfileControllerTest {
final Response response = resources.getJerseyTest()
.target("/v1/profile/PNI:" + AuthHelper.VALID_PNI_TWO)
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY))
.get();
assertThat(response.getStatus()).isEqualTo(401);
@@ -365,7 +365,7 @@ class ProfileControllerTest {
final Response response = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_PNI_TWO)
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("incorrect".getBytes()))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("incorrect".getBytes()))
.get();
assertThat(response.getStatus()).isEqualTo(401);
@@ -1140,7 +1140,7 @@ class ProfileControllerTest {
private static Stream<Arguments> testGetProfileWithExpiringProfileKeyCredential() {
return Stream.of(
Arguments.of(new MultivaluedHashMap<>(Map.of(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_KEY)))),
Arguments.of(new MultivaluedHashMap<>(Map.of(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_KEY)))),
Arguments.of(new MultivaluedHashMap<>(Map.of("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)))),
Arguments.of(new MultivaluedHashMap<>(Map.of("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO))))
);
@@ -1184,7 +1184,7 @@ class ProfileControllerTest {
HexFormat.of().formatHex(credentialRequest.serialize())))
.queryParam("credentialType", "expiringProfileKey")
.request()
.headers(new MultivaluedHashMap<>(Map.of(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_KEY))))
.headers(new MultivaluedHashMap<>(Map.of(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_KEY))))
.get();
assertEquals(400, response.getStatus());