Add API endpoints for waiting for newly-linked devices

This commit is contained in:
Jon Chambers
2024-10-10 10:11:32 -04:00
committed by GitHub
parent 087c2b61ee
commit 8c30a359e7
16 changed files with 793 additions and 122 deletions

View File

@@ -5,12 +5,14 @@
package org.whispersystems.textsecuregcm.controllers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@@ -18,6 +20,7 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.amazonaws.util.Base64;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
@@ -25,6 +28,7 @@ import io.dropwizard.testing.junit5.ResourceExtension;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -35,6 +39,7 @@ import java.util.stream.Stream;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.apache.commons.lang3.RandomStringUtils;
import org.glassfish.jersey.server.ServerProperties;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
@@ -54,6 +59,7 @@ import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventLis
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.entities.DeviceResponse;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
@@ -64,6 +70,7 @@ import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
@@ -79,7 +86,7 @@ import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.VerificationCode;
import org.whispersystems.textsecuregcm.util.LinkDeviceToken;
@ExtendWith(DropwizardExtensionsSupport.class)
class DeviceControllerTest {
@@ -112,6 +119,7 @@ class DeviceControllerTest {
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
.addProvider(new RateLimitExceededExceptionMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager))
.addProvider(new DeviceLimitExceededExceptionMapper())
@@ -122,6 +130,7 @@ class DeviceControllerTest {
void setup() {
when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVerifyDeviceLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getWaitForLinkedDeviceLimiter()).thenReturn(rateLimiter);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
@@ -479,16 +488,17 @@ class DeviceControllerTest {
final Optional<ApnRegistrationId> apnRegistrationId,
final Optional<GcmRegistrationId> gcmRegistrationId) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test");
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
final LinkDeviceToken deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
.get(LinkDeviceToken.class);
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@@ -506,7 +516,7 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.token(),
new AccountAttributes(fetchesMessages, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId));
@@ -539,21 +549,22 @@ class DeviceControllerTest {
final KEMSignedPreKey pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test");
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
final LinkDeviceToken deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
.get(LinkDeviceToken.class);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(aciIdentityKey);
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(pniIdentityKey);
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.token(),
new AccountAttributes(true, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty()));
@@ -931,4 +942,120 @@ class DeviceControllerTest {
verify(clientPublicKeysManager).setPublicKey(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE.getId(), request.publicKey());
}
@Test
void waitForLinkedDevice() {
final DeviceInfo deviceInfo = new DeviceInfo(Device.PRIMARY_ID,
"Device name ciphertext".getBytes(StandardCharsets.UTF_8),
System.currentTimeMillis(),
System.currentTimeMillis());
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo)));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(200, response.getStatus());
final DeviceInfo retrievedDeviceInfo = response.readEntity(DeviceInfo.class);
assertEquals(deviceInfo.id(), retrievedDeviceInfo.id());
assertArrayEquals(deviceInfo.name(), retrievedDeviceInfo.name());
assertEquals(deviceInfo.created(), retrievedDeviceInfo.created());
assertEquals(deviceInfo.lastSeen(), retrievedDeviceInfo.lastSeen());
}
}
@Test
void waitForLinkedDeviceNoDeviceLinked() {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(204, response.getStatus());
}
}
@Test
void waitForLinkedDeviceBadTokenIdentifier() {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException()));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(400, response.getStatus());
}
}
@ParameterizedTest
@MethodSource
void waitForLinkedDeviceBadTimeout(final int timeoutSeconds) {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.queryParam("timeout", timeoutSeconds)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(400, response.getStatus());
}
}
private static List<Integer> waitForLinkedDeviceBadTimeout() {
return List.of(0, -1, 3601);
}
@ParameterizedTest
@MethodSource
void waitForLinkedDeviceBadTokenIdentifierLength(final String tokenIdentifier) {
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(400, response.getStatus());
}
}
private static List<String> waitForLinkedDeviceBadTokenIdentifierLength() {
return List.of(RandomStringUtils.randomAlphanumeric(DeviceController.MIN_TOKEN_IDENTIFIER_LENGTH - 1),
RandomStringUtils.randomAlphanumeric(DeviceController.MAX_TOKEN_IDENTIFIER_LENGTH + 1));
}
@Test
void waitForLinkedDeviceRateLimited() throws RateLimitExceededException {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
doThrow(new RateLimitExceededException(null)).when(rateLimiter).validate(AuthHelper.VALID_UUID);
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(429, response.getStatus());
}
}
}