Add API endpoints for waiting for transfer archives

This commit is contained in:
Jon Chambers
2024-10-11 13:04:51 -04:00
committed by Jon Chambers
parent 7ff48155d6
commit 73fb1fc2ed
4 changed files with 271 additions and 8 deletions

View File

@@ -20,7 +20,6 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.amazonaws.util.Base64;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
@@ -28,6 +27,9 @@ import io.dropwizard.testing.junit5.ResourceExtension;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -64,7 +66,9 @@ import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.LinkDeviceRequest;
import org.whispersystems.textsecuregcm.entities.RemoteAttachment;
import org.whispersystems.textsecuregcm.entities.SetPublicKeyRequest;
import org.whispersystems.textsecuregcm.entities.TransferArchiveUploadedRequest;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
@@ -132,6 +136,8 @@ class DeviceControllerTest {
when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVerifyDeviceLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getWaitForLinkedDeviceLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getUploadTransferArchiveLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getWaitForTransferArchiveLimiter()).thenReturn(rateLimiter);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
@@ -957,7 +963,7 @@ class DeviceControllerTest {
System.currentTimeMillis(),
System.currentTimeMillis());
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo)));
@@ -980,7 +986,7 @@ class DeviceControllerTest {
@Test
void waitForLinkedDeviceNoDeviceLinked() {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
@@ -997,7 +1003,7 @@ class DeviceControllerTest {
@Test
void waitForLinkedDeviceBadTokenIdentifier() {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException()));
@@ -1015,7 +1021,7 @@ class DeviceControllerTest {
@ParameterizedTest
@MethodSource
void waitForLinkedDeviceBadTimeout(final int timeoutSeconds) {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
@@ -1052,7 +1058,7 @@ class DeviceControllerTest {
@Test
void waitForLinkedDeviceRateLimited() throws RateLimitExceededException {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
doThrow(new RateLimitExceededException(null)).when(rateLimiter).validate(AuthHelper.VALID_UUID);
@@ -1065,4 +1071,158 @@ class DeviceControllerTest {
assertEquals(429, response.getStatus());
}
}
@Test
void recordTransferArchiveUploaded() {
final byte deviceId = Device.PRIMARY_ID + 1;
final Instant deviceCreated = Instant.now().truncatedTo(ChronoUnit.MILLIS);
final RemoteAttachment transferArchive =
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)));
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferArchive))
.thenReturn(CompletableFuture.completedFuture(null));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/transfer_archive")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new TransferArchiveUploadedRequest(deviceId, deviceCreated.toEpochMilli(), transferArchive),
MediaType.APPLICATION_JSON_TYPE))) {
assertEquals(204, response.getStatus());
verify(accountsManager)
.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferArchive);
}
}
@ParameterizedTest
@MethodSource
void recordTransferArchiveUploadedBadRequest(final TransferArchiveUploadedRequest request) {
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/transfer_archive")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) {
assertEquals(422, response.getStatus());
verify(accountsManager, never())
.recordTransferArchiveUpload(any(), anyByte(), any(), any());
}
}
@SuppressWarnings("DataFlowIssue")
private static List<TransferArchiveUploadedRequest> recordTransferArchiveUploadedBadRequest() {
final RemoteAttachment validTransferArchive =
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("archive".getBytes(StandardCharsets.UTF_8)));
return List.of(
// Invalid device ID
new TransferArchiveUploadedRequest((byte) -1, System.currentTimeMillis(), validTransferArchive),
// Invalid "created at" timestamp
new TransferArchiveUploadedRequest(Device.PRIMARY_ID, -1, validTransferArchive),
// Missing CDN number
new TransferArchiveUploadedRequest(Device.PRIMARY_ID, System.currentTimeMillis(),
new RemoteAttachment(null, Base64.getUrlEncoder().encodeToString("archive".getBytes(StandardCharsets.UTF_8)))),
// Bad attachment key
new TransferArchiveUploadedRequest(Device.PRIMARY_ID, System.currentTimeMillis(),
new RemoteAttachment(3, "This is not a valid base64 string"))
);
}
@Test
void recordTransferArchiveRateLimited() {
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null)));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/transfer_archive")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new TransferArchiveUploadedRequest(Device.PRIMARY_ID, System.currentTimeMillis(),
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)))),
MediaType.APPLICATION_JSON_TYPE))) {
assertEquals(429, response.getStatus());
verify(accountsManager, never())
.recordTransferArchiveUpload(any(), anyByte(), any(), any());
}
}
@Test
void waitForTransferArchive() {
final RemoteAttachment transferArchive =
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)));
when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(transferArchive)));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/transfer_archive/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(200, response.getStatus());
assertEquals(transferArchive, response.readEntity(RemoteAttachment.class));
}
}
@Test
void waitForTransferArchiveNoArchiveUploaded() {
when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/transfer_archive/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(204, response.getStatus());
}
}
@ParameterizedTest
@MethodSource
void waitForTransferArchiveBadTimeout(final int timeoutSeconds) {
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/transfer_archive/")
.queryParam("timeout", timeoutSeconds)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(400, response.getStatus());
}
}
private static List<Integer> waitForTransferArchiveBadTimeout() {
return List.of(0, -1, 3601);
}
@Test
void waitForTransferArchiveRateLimited() {
when(rateLimiter.validateAsync(anyString()))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null)));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/transfer_archive/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(429, response.getStatus());
}
}
}