mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-22 23:28:04 +01:00
Add API endpoints for waiting for newly-linked devices
This commit is contained in:
@@ -5,12 +5,14 @@
|
||||
package org.whispersystems.textsecuregcm.controllers;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyByte;
|
||||
import static org.mockito.Mockito.anyString;
|
||||
import static org.mockito.Mockito.clearInvocations;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
@@ -18,6 +20,7 @@ import static org.mockito.Mockito.reset;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.amazonaws.util.Base64;
|
||||
import com.google.common.net.HttpHeaders;
|
||||
import io.dropwizard.auth.AuthValueFactoryProvider;
|
||||
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||
@@ -25,6 +28,7 @@ import io.dropwizard.testing.junit5.ResourceExtension;
|
||||
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
|
||||
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Duration;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -35,6 +39,7 @@ import java.util.stream.Stream;
|
||||
import javax.ws.rs.client.Entity;
|
||||
import javax.ws.rs.core.MediaType;
|
||||
import javax.ws.rs.core.Response;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.glassfish.jersey.server.ServerProperties;
|
||||
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
@@ -54,6 +59,7 @@ import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventLis
|
||||
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
|
||||
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
|
||||
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
|
||||
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
|
||||
import org.whispersystems.textsecuregcm.entities.DeviceResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
|
||||
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
|
||||
@@ -64,6 +70,7 @@ import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
|
||||
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
|
||||
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
@@ -79,7 +86,7 @@ import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
|
||||
import org.whispersystems.textsecuregcm.util.Pair;
|
||||
import org.whispersystems.textsecuregcm.util.TestClock;
|
||||
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||
import org.whispersystems.textsecuregcm.util.VerificationCode;
|
||||
import org.whispersystems.textsecuregcm.util.LinkDeviceToken;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
class DeviceControllerTest {
|
||||
@@ -112,6 +119,7 @@ class DeviceControllerTest {
|
||||
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
|
||||
.addProvider(AuthHelper.getAuthFilter())
|
||||
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
|
||||
.addProvider(new RateLimitExceededExceptionMapper())
|
||||
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
||||
.addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager))
|
||||
.addProvider(new DeviceLimitExceededExceptionMapper())
|
||||
@@ -122,6 +130,7 @@ class DeviceControllerTest {
|
||||
void setup() {
|
||||
when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter);
|
||||
when(rateLimiters.getVerifyDeviceLimiter()).thenReturn(rateLimiter);
|
||||
when(rateLimiters.getWaitForLinkedDeviceLimiter()).thenReturn(rateLimiter);
|
||||
|
||||
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
|
||||
|
||||
@@ -479,16 +488,17 @@ class DeviceControllerTest {
|
||||
final Optional<ApnRegistrationId> apnRegistrationId,
|
||||
final Optional<GcmRegistrationId> gcmRegistrationId) {
|
||||
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
|
||||
when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test");
|
||||
|
||||
final Device existingDevice = mock(Device.class);
|
||||
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
|
||||
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
|
||||
|
||||
VerificationCode deviceCode = resources.getJerseyTest()
|
||||
final LinkDeviceToken deviceCode = resources.getJerseyTest()
|
||||
.target("/v1/devices/provisioning/code")
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.get(VerificationCode.class);
|
||||
.get(LinkDeviceToken.class);
|
||||
|
||||
final ECSignedPreKey aciSignedPreKey;
|
||||
final ECSignedPreKey pniSignedPreKey;
|
||||
@@ -506,7 +516,7 @@ class DeviceControllerTest {
|
||||
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
|
||||
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
|
||||
|
||||
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
|
||||
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.token(),
|
||||
new AccountAttributes(fetchesMessages, 1234, 5678, null, null, true, null),
|
||||
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId));
|
||||
|
||||
@@ -539,21 +549,22 @@ class DeviceControllerTest {
|
||||
final KEMSignedPreKey pniPqLastResortPreKey) {
|
||||
|
||||
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
|
||||
when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test");
|
||||
|
||||
final Device existingDevice = mock(Device.class);
|
||||
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
|
||||
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
|
||||
|
||||
VerificationCode deviceCode = resources.getJerseyTest()
|
||||
final LinkDeviceToken deviceCode = resources.getJerseyTest()
|
||||
.target("/v1/devices/provisioning/code")
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.get(VerificationCode.class);
|
||||
.get(LinkDeviceToken.class);
|
||||
|
||||
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(aciIdentityKey);
|
||||
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(pniIdentityKey);
|
||||
|
||||
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
|
||||
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.token(),
|
||||
new AccountAttributes(true, 1234, 5678, null, null, true, null),
|
||||
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty()));
|
||||
|
||||
@@ -931,4 +942,120 @@ class DeviceControllerTest {
|
||||
|
||||
verify(clientPublicKeysManager).setPublicKey(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE.getId(), request.publicKey());
|
||||
}
|
||||
|
||||
@Test
|
||||
void waitForLinkedDevice() {
|
||||
final DeviceInfo deviceInfo = new DeviceInfo(Device.PRIMARY_ID,
|
||||
"Device name ciphertext".getBytes(StandardCharsets.UTF_8),
|
||||
System.currentTimeMillis(),
|
||||
System.currentTimeMillis());
|
||||
|
||||
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
|
||||
|
||||
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo)));
|
||||
|
||||
try (final Response response = resources.getJerseyTest()
|
||||
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.get()) {
|
||||
|
||||
assertEquals(200, response.getStatus());
|
||||
|
||||
final DeviceInfo retrievedDeviceInfo = response.readEntity(DeviceInfo.class);
|
||||
assertEquals(deviceInfo.id(), retrievedDeviceInfo.id());
|
||||
assertArrayEquals(deviceInfo.name(), retrievedDeviceInfo.name());
|
||||
assertEquals(deviceInfo.created(), retrievedDeviceInfo.created());
|
||||
assertEquals(deviceInfo.lastSeen(), retrievedDeviceInfo.lastSeen());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void waitForLinkedDeviceNoDeviceLinked() {
|
||||
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
|
||||
|
||||
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||
|
||||
try (final Response response = resources.getJerseyTest()
|
||||
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.get()) {
|
||||
|
||||
assertEquals(204, response.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void waitForLinkedDeviceBadTokenIdentifier() {
|
||||
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
|
||||
|
||||
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
|
||||
.thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException()));
|
||||
|
||||
try (final Response response = resources.getJerseyTest()
|
||||
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.get()) {
|
||||
|
||||
assertEquals(400, response.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void waitForLinkedDeviceBadTimeout(final int timeoutSeconds) {
|
||||
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
|
||||
|
||||
try (final Response response = resources.getJerseyTest()
|
||||
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
|
||||
.queryParam("timeout", timeoutSeconds)
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.get()) {
|
||||
|
||||
assertEquals(400, response.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
private static List<Integer> waitForLinkedDeviceBadTimeout() {
|
||||
return List.of(0, -1, 3601);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void waitForLinkedDeviceBadTokenIdentifierLength(final String tokenIdentifier) {
|
||||
try (final Response response = resources.getJerseyTest()
|
||||
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.get()) {
|
||||
|
||||
assertEquals(400, response.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
private static List<String> waitForLinkedDeviceBadTokenIdentifierLength() {
|
||||
return List.of(RandomStringUtils.randomAlphanumeric(DeviceController.MIN_TOKEN_IDENTIFIER_LENGTH - 1),
|
||||
RandomStringUtils.randomAlphanumeric(DeviceController.MAX_TOKEN_IDENTIFIER_LENGTH + 1));
|
||||
}
|
||||
|
||||
@Test
|
||||
void waitForLinkedDeviceRateLimited() throws RateLimitExceededException {
|
||||
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
|
||||
|
||||
doThrow(new RateLimitExceededException(null)).when(rateLimiter).validate(AuthHelper.VALID_UUID);
|
||||
|
||||
try (final Response response = resources.getJerseyTest()
|
||||
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.get()) {
|
||||
|
||||
assertEquals(429, response.getStatus());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user