Add API endpoints for waiting for newly-linked devices

This commit is contained in:
Jon Chambers
2024-10-10 10:11:32 -04:00
committed by GitHub
parent 087c2b61ee
commit 8c30a359e7
16 changed files with 793 additions and 122 deletions

View File

@@ -4,22 +4,36 @@
*/
package org.whispersystems.textsecuregcm.controllers;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.lettuce.core.RedisException;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.headers.Header;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.LinkedList;
import java.time.Duration;
import java.util.Arrays;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import javax.validation.Valid;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull;
import javax.validation.constraints.Size;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
import javax.ws.rs.DefaultValue;
import javax.ws.rs.ForbiddenException;
import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam;
@@ -27,10 +41,12 @@ import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import io.swagger.v3.oas.annotations.tags.Tag;
import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.auth.LinkedDeviceRefreshRequirementProvider;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
@@ -47,6 +63,8 @@ import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
import org.whispersystems.textsecuregcm.entities.SetPublicKeyRequest;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
@@ -54,7 +72,11 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.LinkDeviceTokenAlreadyUsedException;
import org.whispersystems.textsecuregcm.util.VerificationCode;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.LinkDeviceToken;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.auth.Mutable;
import org.whispersystems.websocket.auth.ReadOnly;
@@ -69,6 +91,21 @@ public class DeviceController {
private final RateLimiters rateLimiters;
private final Map<String, Integer> maxDeviceConfiguration;
private final EnumMap<ClientPlatform, AtomicInteger> linkedDeviceListenersByPlatform;
private final AtomicInteger linkedDeviceListenersForUnrecognizedPlatforms;
private static final String LINKED_DEVICE_LISTENER_GAUGE_NAME =
MetricsUtil.name(DeviceController.class, "linkedDeviceListeners");
private static final String WAIT_FOR_LINKED_DEVICE_TIMER_NAME =
MetricsUtil.name(DeviceController.class, "waitForLinkedDeviceDuration");
@VisibleForTesting
static final int MIN_TOKEN_IDENTIFIER_LENGTH = 32;
@VisibleForTesting
static final int MAX_TOKEN_IDENTIFIER_LENGTH = 64;
public DeviceController(final AccountsManager accounts,
final ClientPublicKeysManager clientPublicKeysManager,
final RateLimiters rateLimiters,
@@ -78,19 +115,32 @@ public class DeviceController {
this.clientPublicKeysManager = clientPublicKeysManager;
this.rateLimiters = rateLimiters;
this.maxDeviceConfiguration = maxDeviceConfiguration;
linkedDeviceListenersByPlatform = Arrays.stream(ClientPlatform.values())
.collect(Collectors.toMap(
Function.identity(),
clientPlatform -> buildGauge(clientPlatform.name().toLowerCase()),
(a, b) -> {
throw new AssertionError("Duplicate client platform enumeration key");
},
() -> new EnumMap<>(ClientPlatform.class)
));
linkedDeviceListenersForUnrecognizedPlatforms = buildGauge("unknown");
}
private static AtomicInteger buildGauge(final String clientPlatformName) {
return Metrics.gauge(LINKED_DEVICE_LISTENER_GAUGE_NAME,
Tags.of(io.micrometer.core.instrument.Tag.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatformName)),
new AtomicInteger(0));
}
@GET
@Produces(MediaType.APPLICATION_JSON)
public DeviceInfoList getDevices(@ReadOnly @Auth AuthenticatedDevice auth) {
List<DeviceInfo> devices = new LinkedList<>();
for (Device device : auth.getAccount().getDevices()) {
devices.add(new DeviceInfo(device.getId(), device.getName(),
device.getLastSeen(), device.getCreated()));
}
return new DeviceInfoList(devices);
return new DeviceInfoList(auth.getAccount().getDevices().stream()
.map(DeviceInfo::forDevice)
.toList());
}
@DELETE
@@ -138,7 +188,7 @@ public class DeviceController {
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public VerificationCode createDeviceToken(@ReadOnly @Auth AuthenticatedDevice auth)
public LinkDeviceToken createDeviceToken(@ReadOnly @Auth AuthenticatedDevice auth)
throws RateLimitExceededException, DeviceLimitExceededException {
final Account account = auth.getAccount();
@@ -159,7 +209,9 @@ public class DeviceController {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
return new VerificationCode(accounts.generateDeviceLinkingToken(account.getUuid()));
final String token = accounts.generateLinkDeviceToken(account.getUuid());
return new LinkDeviceToken(token, AccountsManager.getLinkDeviceTokenIdentifier(token));
}
@PUT
@@ -266,6 +318,83 @@ public class DeviceController {
}
}
@GET
@Path("/wait_for_linked_device/{tokenIdentifier}")
@Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Wait for a new device to be linked to an account",
description = """
Waits for a new device to be linked to an account and returns basic information about the new device when
available.
""")
@ApiResponse(responseCode = "200", description = "The specified was linked to an account")
@ApiResponse(responseCode = "204", description = "No device was linked to the account before the call completed")
@ApiResponse(responseCode = "400", description = "The given token identifier or timeout was invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
@Schema(description = "Basic information about the linked device", implementation = DeviceInfo.class)
public CompletableFuture<Response> waitForLinkedDevice(
@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,
@PathParam("tokenIdentifier")
@Schema(description = "A 'link device' token identifier provided by the 'create link device token' endpoint")
@Size(min = MIN_TOKEN_IDENTIFIER_LENGTH, max = MAX_TOKEN_IDENTIFIER_LENGTH)
final String tokenIdentifier,
@QueryParam("timeout")
@DefaultValue("30")
@Min(1)
@Max(3600)
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED,
minimum = "1",
maximum = "3600",
description = """
The amount of time (in seconds) to wait for a response. If the expected device is not linked within the
given amount of time, this endpoint will return a status of HTTP/204.
""") final int timeoutSeconds,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) throws RateLimitExceededException {
rateLimiters.getWaitForLinkedDeviceLimiter().validate(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI));
final AtomicInteger linkedDeviceListenerCounter = getCounterForLinkedDeviceListeners(userAgent);
linkedDeviceListenerCounter.incrementAndGet();
final Timer.Sample sample = Timer.start();
try {
return accounts.waitForNewLinkedDevice(tokenIdentifier, Duration.ofSeconds(timeoutSeconds))
.thenApply(maybeDeviceInfo -> maybeDeviceInfo
.map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build())
.orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build()))
.exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class,
e -> Response.status(Response.Status.BAD_REQUEST).build()))
.whenComplete((response, throwable) -> {
linkedDeviceListenerCounter.decrementAndGet();
if (response != null) {
sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME)
.publishPercentileHistogram(true)
.tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
io.micrometer.core.instrument.Tag.of("deviceFound",
String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode()))))
.register(Metrics.globalRegistry));
}
});
} catch (final RedisException e) {
// `waitForNewLinkedDevice` could fail synchronously if the Redis circuit breaker is open; prevent counter drift
// if that happens
linkedDeviceListenerCounter.decrementAndGet();
throw e;
}
}
private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) {
try {
return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).getPlatform());
} catch (final UnrecognizedUserAgentException ignored) {
return linkedDeviceListenersForUnrecognizedPlatforms;
}
}
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Path("/unauthenticated_delivery")