Update chat to send three search keys in one request to KT

This commit is contained in:
Katherine
2024-10-29 09:52:26 -04:00
committed by GitHub
parent 89292e238b
commit 712f3affd9
6 changed files with 240 additions and 160 deletions

View File

@@ -12,34 +12,29 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyController.ACI_PREFIX;
import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyController.E164_PREFIX;
import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyController.USERNAME_PREFIX;
import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyController.getFullSearchKeyByteString;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.net.HttpHeaders;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
@@ -59,6 +54,12 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.signal.keytransparency.client.E164SearchRequest;
import org.signal.keytransparency.client.FullTreeHead;
import org.signal.keytransparency.client.SearchProof;
import org.signal.keytransparency.client.SearchResponse;
import org.signal.keytransparency.client.TreeSearchResponse;
import org.signal.keytransparency.client.UpdateValue;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
@@ -137,12 +138,25 @@ public class KeyTransparencyControllerTest {
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@ParameterizedTest
@MethodSource
void searchSuccess(final Optional<String> e164, final Optional<byte[]> usernameHash, final int expectedNumClientCalls,
final Set<ByteString> expectedSearchKeys,
final Set<ByteString> expectedValues,
final List<Optional<ByteString>> expectedUnidentifiedAccessKey) {
when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16)));
void searchSuccess(final Optional<String> e164, final Optional<byte[]> usernameHash) {
final TreeSearchResponse aciSearchResponse = TreeSearchResponse.newBuilder()
.setOpening(ByteString.copyFrom(TestRandomUtil.nextBytes(16)))
.setTreeHead(FullTreeHead.getDefaultInstance())
.setSearch(SearchProof.getDefaultInstance())
.setValue(UpdateValue.newBuilder()
.setValue(ByteString.copyFrom(TestRandomUtil.nextBytes(16)))
.build())
.build();
final SearchResponse.Builder searchResponseBuilder = SearchResponse.newBuilder()
.setTreeHead(FullTreeHead.getDefaultInstance())
.setAci(aciSearchResponse);
e164.ifPresent(ignored -> searchResponseBuilder.setE164(TreeSearchResponse.getDefaultInstance()));
usernameHash.ifPresent(ignored -> searchResponseBuilder.setUsernameHash(TreeSearchResponse.getDefaultInstance()));
when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(searchResponseBuilder.build()));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/search")
@@ -158,54 +172,51 @@ public class KeyTransparencyControllerTest {
final KeyTransparencySearchResponse keyTransparencySearchResponse = response.readEntity(
KeyTransparencySearchResponse.class);
assertNotNull(keyTransparencySearchResponse.fullTreeHead());
assertNotNull(keyTransparencySearchResponse.aciSearchResponse());
usernameHash.ifPresentOrElse(
ignored -> assertTrue(keyTransparencySearchResponse.usernameHashSearchResponse().isPresent()),
() -> assertTrue(keyTransparencySearchResponse.usernameHashSearchResponse().isEmpty()));
assertEquals(aciSearchResponse, TreeSearchResponse.parseFrom(keyTransparencySearchResponse.aciSearchResponse()));
e164.ifPresentOrElse(ignored -> assertTrue(keyTransparencySearchResponse.e164SearchResponse().isPresent()),
() -> assertTrue(keyTransparencySearchResponse.e164SearchResponse().isEmpty()));
e164.ifPresent(ignored -> assertNotNull(keyTransparencySearchResponse.e164SearchResponse()));
usernameHash.ifPresent(ignored -> assertNotNull(keyTransparencySearchResponse.usernameHashSearchResponse()));
ArgumentCaptor<ByteString> valueArguments = ArgumentCaptor.forClass(ByteString.class);
ArgumentCaptor<ByteString> searchKeyArguments = ArgumentCaptor.forClass(ByteString.class);
ArgumentCaptor<Optional<ByteString>> unidentifiedAccessKeyArgument = ArgumentCaptor.forClass(Optional.class);
ArgumentCaptor<ByteString> aciArgument = ArgumentCaptor.forClass(ByteString.class);
ArgumentCaptor<ByteString> aciIdentityKeyArgument = ArgumentCaptor.forClass(ByteString.class);
ArgumentCaptor<Optional<ByteString>> usernameHashArgument = ArgumentCaptor.forClass(Optional.class);
ArgumentCaptor<Optional<E164SearchRequest>> e164Argument = ArgumentCaptor.forClass(Optional.class);
verify(keyTransparencyServiceClient, times(expectedNumClientCalls)).search(searchKeyArguments.capture(), valueArguments.capture(), unidentifiedAccessKeyArgument.capture(), eq(Optional.of(3L)), eq(Optional.of(4L)),
verify(keyTransparencyServiceClient).search(aciArgument.capture(), aciIdentityKeyArgument.capture(),
usernameHashArgument.capture(), e164Argument.capture(), eq(Optional.of(3L)), eq(Optional.of(4L)),
eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT));
assertEquals(expectedSearchKeys, new HashSet<>(searchKeyArguments.getAllValues()));
assertEquals(expectedValues, new HashSet<>(valueArguments.getAllValues()));
assertEquals(expectedUnidentifiedAccessKey, unidentifiedAccessKeyArgument.getAllValues());
assertArrayEquals(ACI.toCompactByteArray(), aciArgument.getValue().toByteArray());
assertArrayEquals(ACI_IDENTITY_KEY.serialize(), aciIdentityKeyArgument.getValue().toByteArray());
if (usernameHash.isPresent()) {
assertArrayEquals(USERNAME_HASH, usernameHashArgument.getValue().orElseThrow().toByteArray());
} else {
assertTrue(usernameHashArgument.getValue().isEmpty());
}
if (e164.isPresent()) {
final E164SearchRequest expected = E164SearchRequest.newBuilder()
.setE164(e164.get())
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey.get()))
.build();
assertEquals(expected, e164Argument.getValue().orElseThrow());
} else {
assertTrue(e164Argument.getValue().isEmpty());
}
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
}
private static Stream<Arguments> searchSuccess() {
final byte[] aciBytes = ACI.toCompactByteArray();
final ByteString aciValueByteString = ByteString.copyFrom(aciBytes);
final byte[] aciIdentityKeyBytes = ACI_IDENTITY_KEY.serialize();
final ByteString aciIdentityKeyValueByteString = ByteString.copyFrom(aciIdentityKeyBytes);
return Stream.of(
// Only looking up ACI; ACI identity key should be the only value provided; no UAK
Arguments.of(Optional.empty(), Optional.empty(), 1,
Set.of(getFullSearchKeyByteString(ACI_PREFIX, aciBytes)),
Set.of(aciIdentityKeyValueByteString),
List.of(Optional.empty())),
// Looking up ACI and username hash; ACI identity key and ACI should be the values provided; no UAK
Arguments.of(Optional.empty(), Optional.of(USERNAME_HASH), 2,
Set.of(getFullSearchKeyByteString(ACI_PREFIX, aciBytes),
getFullSearchKeyByteString(USERNAME_PREFIX, USERNAME_HASH)),
Set.of(aciIdentityKeyValueByteString, aciValueByteString),
List.of(Optional.empty(), Optional.empty())),
// Looking up ACI and phone number; ACI identity key and ACI should be the values provided; must provide UAK
Arguments.of(Optional.of(NUMBER), Optional.empty(), 2,
Set.of(getFullSearchKeyByteString(ACI_PREFIX, aciBytes),
getFullSearchKeyByteString(E164_PREFIX, NUMBER.getBytes(StandardCharsets.UTF_8))),
Set.of(aciValueByteString, aciIdentityKeyValueByteString),
List.of(Optional.empty(), Optional.of(ByteString.copyFrom(UNIDENTIFIED_ACCESS_KEY))))
Arguments.of(Optional.of(NUMBER), Optional.empty()),
Arguments.of(Optional.empty(), Optional.of(USERNAME_HASH)),
Arguments.of(Optional.of(NUMBER), Optional.of(USERNAME_HASH))
);
}
@@ -226,7 +237,7 @@ public class KeyTransparencyControllerTest {
@ParameterizedTest
@MethodSource
void searchGrpcErrors(final Status grpcStatus, final int httpStatus) {
when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), any()))
when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus))));
final Invocation.Builder request = resources.getJerseyTest()
@@ -236,7 +247,7 @@ public class KeyTransparencyControllerTest {
Entity.json(createRequestJson(new KeyTransparencySearchRequest(ACI, Optional.empty(), Optional.empty(),
ACI_IDENTITY_KEY, Optional.empty(), Optional.empty(), Optional.empty()))))) {
assertEquals(httpStatus, response.getStatus());
verify(keyTransparencyServiceClient, times(1)).search(any(), any(), any(), any(), any(), any());
verify(keyTransparencyServiceClient, times(1)).search(any(), any(), any(), any(), any(), any(), any());
}
}