Add API endpoints for waiting for account restoration requests

This commit is contained in:
Jon Chambers
2024-10-24 12:25:40 -04:00
committed by GitHub
parent 5c4cafcb6f
commit 324913d2da
6 changed files with 298 additions and 12 deletions

View File

@@ -51,6 +51,7 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
@@ -61,6 +62,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.entities.RestoreAccountRequest;
import org.whispersystems.textsecuregcm.entities.LinkDeviceResponse;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
@@ -1019,7 +1021,7 @@ class DeviceControllerTest {
}
@ParameterizedTest
@MethodSource
@ValueSource(ints = {0, -1, 3601})
void waitForLinkedDeviceBadTimeout(final int timeoutSeconds) {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
@@ -1034,10 +1036,6 @@ class DeviceControllerTest {
}
}
private static List<Integer> waitForLinkedDeviceBadTimeout() {
return List.of(0, -1, 3601);
}
@ParameterizedTest
@MethodSource
void waitForLinkedDeviceBadTokenIdentifierLength(final String tokenIdentifier) {
@@ -1194,7 +1192,7 @@ class DeviceControllerTest {
}
@ParameterizedTest
@MethodSource
@ValueSource(ints = {0, -1, 3601})
void waitForTransferArchiveBadTimeout(final int timeoutSeconds) {
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/transfer_archive/")
@@ -1207,10 +1205,6 @@ class DeviceControllerTest {
}
}
private static List<Integer> waitForTransferArchiveBadTimeout() {
return List.of(0, -1, 3601);
}
@Test
void waitForTransferArchiveRateLimited() {
when(rateLimiter.validateAsync(anyString()))
@@ -1225,4 +1219,101 @@ class DeviceControllerTest {
assertEquals(429, response.getStatus());
}
}
@Test
void recordRestoreAccountRequest() {
final String token = RandomStringUtils.randomAlphanumeric(16);
final RestoreAccountRequest restoreAccountRequest =
new RestoreAccountRequest(RestoreAccountRequest.Method.LOCAL_BACKUP);
when(accountsManager.recordRestoreAccountRequest(token, restoreAccountRequest))
.thenReturn(CompletableFuture.completedFuture(null));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/restore_account/" + token)
.request()
.put(Entity.json(restoreAccountRequest))) {
assertEquals(204, response.getStatus());
}
}
@Test
void recordRestoreAccountRequestBadToken() {
final String token = RandomStringUtils.randomAlphanumeric(128);
final RestoreAccountRequest restoreAccountRequest =
new RestoreAccountRequest(RestoreAccountRequest.Method.LOCAL_BACKUP);
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/restore_account/" + token)
.request()
.put(Entity.json(restoreAccountRequest))) {
assertEquals(400, response.getStatus());
}
}
@Test
void recordRestoreAccountRequestInvalidRequest() {
final String token = RandomStringUtils.randomAlphanumeric(16);
final RestoreAccountRequest restoreAccountRequest = new RestoreAccountRequest(null);
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/restore_account/" + token)
.request()
.put(Entity.json(restoreAccountRequest))) {
assertEquals(422, response.getStatus());
}
}
@Test
void waitForDeviceTransferRequest() {
final String token = RandomStringUtils.randomAlphanumeric(16);
final RestoreAccountRequest restoreAccountRequest =
new RestoreAccountRequest(RestoreAccountRequest.Method.LOCAL_BACKUP);
when(accountsManager.waitForRestoreAccountRequest(eq(token), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(restoreAccountRequest)));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/restore_account/" + token)
.request()
.get()) {
assertEquals(200, response.getStatus());
assertEquals(restoreAccountRequest, response.readEntity(RestoreAccountRequest.class));
}
}
@Test
void waitForDeviceTransferRequestNoRequestIssued() {
final String token = RandomStringUtils.randomAlphanumeric(16);
when(accountsManager.waitForRestoreAccountRequest(eq(token), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/restore_account/" + token)
.request()
.get()) {
assertEquals(204, response.getStatus());
}
}
@ParameterizedTest
@ValueSource(ints = {0, -1, 3601})
void waitForDeviceTransferRequestBadTimeout(final int timeoutSeconds) {
final String token = RandomStringUtils.randomAlphanumeric(16);
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/restore_account/" + token)
.queryParam("timeout", timeoutSeconds)
.request()
.get()) {
assertEquals(400, response.getStatus());
}
}
}

View File

@@ -5,11 +5,13 @@
package org.whispersystems.textsecuregcm.storage;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.RestoreAccountRequest;
import org.whispersystems.textsecuregcm.entities.RemoteAttachment;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
@@ -34,7 +36,7 @@ import static org.mockito.Mockito.when;
// ThreadMode.SEPARATE_THREAD protects against hangs in the remote Redis calls, as this mode allows the test code to be
// preempted by the timeout check
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
public class AccountsManagerTransferArchiveIntegrationTest {
public class AccountsManagerDeviceTransferIntegrationTest {
@RegisterExtension
static final RedisServerExtension PUBSUB_SERVER_EXTENSION = RedisServerExtension.builder().build();
@@ -144,4 +146,41 @@ public class AccountsManagerTransferArchiveIntegrationTest {
assertEquals(Optional.empty(),
accountsManager.waitForTransferArchive(account, device, Duration.ofMillis(1)).join());
}
@Test
void waitForRestoreAccountRequest() {
final String token = RandomStringUtils.randomAlphanumeric(16);
final RestoreAccountRequest restoreAccountRequest =
new RestoreAccountRequest(RestoreAccountRequest.Method.DEVICE_TRANSFER);
final CompletableFuture<Optional<RestoreAccountRequest>> displacedFuture =
accountsManager.waitForRestoreAccountRequest(token, Duration.ofSeconds(5));
final CompletableFuture<Optional<RestoreAccountRequest>> activeFuture =
accountsManager.waitForRestoreAccountRequest(token, Duration.ofSeconds(5));
assertEquals(Optional.empty(), displacedFuture.join());
accountsManager.recordRestoreAccountRequest(token, restoreAccountRequest).join();
assertEquals(Optional.of(restoreAccountRequest), activeFuture.join());
}
@Test
void waitForRestoreAccountRequestAlreadyRequested() {
final String token = RandomStringUtils.randomAlphanumeric(16);
final RestoreAccountRequest restoreAccountRequest =
new RestoreAccountRequest(RestoreAccountRequest.Method.DEVICE_TRANSFER);
accountsManager.recordRestoreAccountRequest(token, restoreAccountRequest).join();
assertEquals(Optional.of(restoreAccountRequest),
accountsManager.waitForRestoreAccountRequest(token, Duration.ofSeconds(5)).join());
}
@Test
void waitForRestoreAccountRequestTimeout() {
assertEquals(Optional.empty(),
accountsManager.waitForRestoreAccountRequest(RandomStringUtils.randomAlphanumeric(16), Duration.ofMillis(1)).join());
}
}