Implement key transparency endpoints using simple-grpc

This commit is contained in:
Katherine
2025-06-24 14:01:35 -04:00
committed by GitHub
parent 51773f5709
commit 059caa4c57
8 changed files with 562 additions and 116 deletions

View File

@@ -35,12 +35,8 @@ import jakarta.ws.rs.client.WebTarget;
import jakarta.ws.rs.core.Response;
import java.io.UncheckedIOException;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
@@ -54,8 +50,10 @@ import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.signal.keytransparency.client.CondensedTreeSearchResponse;
import org.signal.keytransparency.client.DistinguishedResponse;
import org.signal.keytransparency.client.E164SearchRequest;
import org.signal.keytransparency.client.FullTreeHead;
import org.signal.keytransparency.client.MonitorResponse;
import org.signal.keytransparency.client.SearchProof;
import org.signal.keytransparency.client.SearchResponse;
import org.signal.keytransparency.client.UpdateValue;
@@ -81,16 +79,16 @@ import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
@ExtendWith(DropwizardExtensionsSupport.class)
public class KeyTransparencyControllerTest {
private static final String NUMBER = PhoneNumberUtil.getInstance().format(
public static final String NUMBER = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
private static final AciServiceIdentifier ACI = new AciServiceIdentifier(UUID.randomUUID());
private static final byte[] USERNAME_HASH = TestRandomUtil.nextBytes(20);
public static final AciServiceIdentifier ACI = new AciServiceIdentifier(UUID.randomUUID());
public static final byte[] USERNAME_HASH = TestRandomUtil.nextBytes(20);
private static final TestRemoteAddressFilterProvider TEST_REMOTE_ADDRESS_FILTER_PROVIDER
= new TestRemoteAddressFilterProvider("127.0.0.1");
private static final IdentityKey ACI_IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey());
public static final IdentityKey ACI_IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey());
private static final byte[] COMMITMENT_INDEX = new byte[32];
private static final byte[] UNIDENTIFIED_ACCESS_KEY = new byte[16];
public static final byte[] UNIDENTIFIED_ACCESS_KEY = new byte[16];
private final KeyTransparencyServiceClient keyTransparencyServiceClient = mock(KeyTransparencyServiceClient.class);
private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter searchRatelimiter = mock(RateLimiter.class);
@@ -141,8 +139,8 @@ public class KeyTransparencyControllerTest {
e164.ifPresent(ignored -> searchResponseBuilder.setE164(CondensedTreeSearchResponse.getDefaultInstance()));
usernameHash.ifPresent(ignored -> searchResponseBuilder.setUsernameHash(CondensedTreeSearchResponse.getDefaultInstance()));
when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), anyLong(), any()))
.thenReturn(CompletableFuture.completedFuture(searchResponseBuilder.build().toByteArray()));
when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), anyLong()))
.thenReturn(searchResponseBuilder.build());
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/search")
@@ -167,8 +165,7 @@ public class KeyTransparencyControllerTest {
ArgumentCaptor<Optional<E164SearchRequest>> e164Argument = ArgumentCaptor.forClass(Optional.class);
verify(keyTransparencyServiceClient).search(aciArgument.capture(), aciIdentityKeyArgument.capture(),
usernameHashArgument.capture(), e164Argument.capture(), eq(Optional.of(3L)), eq(4L),
eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT));
usernameHashArgument.capture(), e164Argument.capture(), eq(Optional.of(3L)), eq(4L));
assertArrayEquals(ACI.toCompactByteArray(), aciArgument.getValue().toByteArray());
assertArrayEquals(ACI_IDENTITY_KEY.serialize(), aciIdentityKeyArgument.getValue().toByteArray());
@@ -218,8 +215,8 @@ public class KeyTransparencyControllerTest {
@ParameterizedTest
@MethodSource
void searchGrpcErrors(final Status grpcStatus, final int httpStatus) {
when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), anyLong(), any()))
.thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus))));
when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), anyLong()))
.thenThrow(new StatusRuntimeException(grpcStatus));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/search")
@@ -228,7 +225,7 @@ public class KeyTransparencyControllerTest {
Entity.json(createRequestJson(new KeyTransparencySearchRequest(ACI, Optional.empty(), Optional.empty(),
ACI_IDENTITY_KEY, Optional.empty(), Optional.empty(), 4L))))) {
assertEquals(httpStatus, response.getStatus());
verify(keyTransparencyServiceClient, times(1)).search(any(), any(), any(), any(), any(), anyLong(), any());
verify(keyTransparencyServiceClient, times(1)).search(any(), any(), any(), any(), any(), anyLong());
}
}
@@ -295,8 +292,8 @@ public class KeyTransparencyControllerTest {
@Test
void monitorSuccess() {
when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong(), any()))
.thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16)));
when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong()))
.thenReturn(MonitorResponse.getDefaultInstance());
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/monitor")
@@ -314,7 +311,7 @@ public class KeyTransparencyControllerTest {
assertNotNull(keyTransparencyMonitorResponse.serializedResponse());
verify(keyTransparencyServiceClient, times(1)).monitor(
any(), any(), any(), eq(3L), eq(4L), eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT));
any(), any(), any(), eq(3L), eq(4L));
}
}
@@ -337,8 +334,8 @@ public class KeyTransparencyControllerTest {
@ParameterizedTest
@MethodSource
void monitorGrpcErrors(final Status grpcStatus, final int httpStatus) {
when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong(), any()))
.thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus))));
when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong()))
.thenThrow(new StatusRuntimeException(grpcStatus));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/monitor")
@@ -349,7 +346,7 @@ public class KeyTransparencyControllerTest {
new KeyTransparencyMonitorRequest.AciMonitor(ACI, 3, COMMITMENT_INDEX),
Optional.empty(), Optional.empty(), 3L, 4L))))) {
assertEquals(httpStatus, response.getStatus());
verify(keyTransparencyServiceClient, times(1)).monitor(any(), any(), any(), anyLong(), anyLong(), any());
verify(keyTransparencyServiceClient, times(1)).monitor(any(), any(), any(), anyLong(), anyLong());
}
}
@@ -500,8 +497,8 @@ public class KeyTransparencyControllerTest {
@ParameterizedTest
@CsvSource(", 1")
void distinguishedSuccess(@Nullable Long lastTreeHeadSize) {
when(keyTransparencyServiceClient.getDistinguishedKey(any(), any()))
.thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16)));
when(keyTransparencyServiceClient.getDistinguishedKey(any()))
.thenReturn(DistinguishedResponse.getDefaultInstance());
WebTarget webTarget = resources.getJerseyTest()
.target("/v1/key-transparency/distinguished");
@@ -518,8 +515,7 @@ public class KeyTransparencyControllerTest {
assertNotNull(distinguishedKeyResponse.serializedResponse());
verify(keyTransparencyServiceClient, times(1))
.getDistinguishedKey(eq(Optional.ofNullable(lastTreeHeadSize)),
eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT));
.getDistinguishedKey(eq(Optional.ofNullable(lastTreeHeadSize)));
}
}
@@ -538,15 +534,15 @@ public class KeyTransparencyControllerTest {
@ParameterizedTest
@MethodSource
void distinguishedGrpcErrors(final Status grpcStatus, final int httpStatus) {
when(keyTransparencyServiceClient.getDistinguishedKey(any(), any()))
.thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus))));
when(keyTransparencyServiceClient.getDistinguishedKey(any()))
.thenThrow(new StatusRuntimeException(grpcStatus));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/distinguished")
.request();
try (Response response = request.get()) {
assertEquals(httpStatus, response.getStatus());
verify(keyTransparencyServiceClient).getDistinguishedKey(any(), any());
verify(keyTransparencyServiceClient).getDistinguishedKey(any());
}
}
@@ -561,8 +557,8 @@ public class KeyTransparencyControllerTest {
@Test
void distinguishedInvalidRequest() {
when(keyTransparencyServiceClient.getDistinguishedKey(any(), any()))
.thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16)));
when(keyTransparencyServiceClient.getDistinguishedKey(any()))
.thenReturn(DistinguishedResponse.getDefaultInstance());
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/distinguished")