diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index ad52a22d0..57f74d212 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -115,6 +115,8 @@ public class DeviceController { private static final String WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME = MetricsUtil.name(DeviceController.class, "waitForTransferArchiveDuration"); + private static final String RECORD_TRANSFER_ARCHIVE_UPLOADED_COUNTER_NAME = MetricsUtil.name(DeviceController.class, "recordTransferArchiveUploaded"); + private static final String HAS_REGISTRATION_ID_TAG_NAME = "hasRegistrationId"; @VisibleForTesting static final int MIN_TOKEN_IDENTIFIER_LENGTH = 32; @@ -533,8 +535,14 @@ public class DeviceController { @ApiResponse(responseCode = "422", description = "The request object could not be parsed or was otherwise invalid") @ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay") public CompletionStage recordTransferArchiveUploaded(@Auth final AuthenticatedDevice authenticatedDevice, - @NotNull @Valid final TransferArchiveUploadedRequest transferArchiveUploadedRequest) { - + @NotNull @Valid final TransferArchiveUploadedRequest transferArchiveUploadedRequest, + @HeaderParam(HttpHeaders.USER_AGENT) @Nullable String userAgent) { + Metrics.counter(RECORD_TRANSFER_ARCHIVE_UPLOADED_COUNTER_NAME, Tags.of( + UserAgentTagUtil.getPlatformTag(userAgent), + io.micrometer.core.instrument.Tag.of( + HAS_REGISTRATION_ID_TAG_NAME, + String.valueOf(transferArchiveUploadedRequest.registrationId().isPresent())) + )).increment(); return rateLimiters.getUploadTransferArchiveLimiter() .validateAsync(authenticatedDevice.accountIdentifier()) .thenCompose(ignored -> accounts.getByAccountIdentifierAsync(authenticatedDevice.accountIdentifier())) @@ -544,7 +552,8 @@ public class DeviceController { return accounts.recordTransferArchiveUpload(account, transferArchiveUploadedRequest.destinationDeviceId(), - Instant.ofEpochMilli(transferArchiveUploadedRequest.destinationDeviceCreated()), + transferArchiveUploadedRequest.destinationDeviceCreated().map(Instant::ofEpochMilli), + transferArchiveUploadedRequest.registrationId(), transferArchiveUploadedRequest.transferArchive()); }); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/TransferArchiveUploadedRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/TransferArchiveUploadedRequest.java index f94331443..d1926fa9c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/TransferArchiveUploadedRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/TransferArchiveUploadedRequest.java @@ -7,21 +7,30 @@ package org.whispersystems.textsecuregcm.entities; import io.swagger.v3.oas.annotations.media.Schema; import jakarta.validation.Valid; +import jakarta.validation.constraints.AssertTrue; import jakarta.validation.constraints.Max; import jakarta.validation.constraints.Min; import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.Positive; import org.whispersystems.textsecuregcm.storage.Device; +import java.util.Optional; + public record TransferArchiveUploadedRequest( @Min(1) @Max(Device.MAXIMUM_DEVICE_ID) @Schema(description = "The ID of the device for which the transfer archive has been prepared") byte destinationDeviceId, - @Positive - @Schema(description = "The timestamp, in milliseconds since the epoch, at which the destination device was created") - long destinationDeviceCreated, + @Schema(description = """ + The timestamp, in milliseconds since the epoch, at which the destination device was created. + Deprecated in favor of registrationId. + """, deprecated = true) + @Deprecated + Optional<@Positive Long> destinationDeviceCreated, + + @Schema(description = "The registration ID of the destination device") + Optional<@Min(0) @Max(Device.MAX_REGISTRATION_ID) Integer> registrationId, @NotNull @Valid @@ -29,4 +38,10 @@ public record TransferArchiveUploadedRequest( The location of the transfer archive if the archive was successfully uploaded, otherwise a error indicating that the upload has failed and the destination device should stop waiting """, oneOf = {RemoteAttachment.class, RemoteAttachmentError.class}) - TransferArchiveResult transferArchive) {} + TransferArchiveResult transferArchive) { + @AssertTrue + @Schema(hidden = true) + public boolean isExactlyOneDisambiguatorProvided() { + return destinationDeviceCreated.isPresent() ^ registrationId.isPresent(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index c1ab19977..4cdedcd3b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -50,6 +50,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -85,6 +86,7 @@ import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecoveryClient; import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.Pair; +import org.whispersystems.textsecuregcm.util.RegistrationIdValidator; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.Util; import reactor.core.publisher.Flux; @@ -111,6 +113,8 @@ public class AccountsManager extends RedisPubSubAdapter implemen private static final String DELETE_COUNTER_NAME = name(AccountsManager.class, "deleteCounter"); private static final String COUNTRY_CODE_TAG_NAME = "country"; private static final String DELETION_REASON_TAG_NAME = "reason"; + private static final String TIMESTAMP_BASED_TRANSFER_ARCHIVE_KEY_COUNTER_NAME = name(AccountsManager.class, "timestampRedisKeyCounter"); + private static final String REGISTRATION_ID_BASED_TRANSFER_ARCHIVE_KEY_COUNTER_NAME = name(AccountsManager.class,"registrationIdRedisKeyCounter"); private static final Logger logger = LoggerFactory.getLogger(AccountsManager.class); @@ -140,7 +144,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen private final Map>> waitForDeviceFuturesByTokenIdentifier = new ConcurrentHashMap<>(); - private final Map>> waitForTransferArchiveFuturesByDeviceIdentifier = + private final Map>> waitForTransferArchiveFuturesByDeviceIdentifier = new ConcurrentHashMap<>(); private final Map>> waitForRestoreAccountRequestFuturesByToken = @@ -155,6 +159,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen private static final Duration RECENTLY_ADDED_TRANSFER_ARCHIVE_TTL = Duration.ofHours(1); private static final String TRANSFER_ARCHIVE_PREFIX = "transfer_archive::"; private static final String TRANSFER_ARCHIVE_KEYSPACE_PATTERN = "__keyspace@0__:" + TRANSFER_ARCHIVE_PREFIX + "*"; + private static final String TRANSFER_ARCHIVE_REGISTRATION_ID_PATTERN = "registrationId"; private static final Duration RESTORE_ACCOUNT_REQUEST_TTL = Duration.ofHours(1); private static final String RESTORE_ACCOUNT_REQUEST_PREFIX = "restore_account::"; @@ -194,7 +199,14 @@ public class AccountsManager extends RedisPubSubAdapter implemen } } - private record TimestampedDeviceIdentifier(UUID accountIdentifier, byte deviceId, Instant deviceCreationTimestamp) { + private interface DeviceIdentifier {} + + private record TimestampDeviceIdentifier(UUID accountIdentifier, byte deviceId, Instant deviceCreationTimestamp) + implements DeviceIdentifier { + } + + private record RegistrationIdDeviceIdentifier(UUID accountIdentifier, byte deviceId, + int registrationId) implements DeviceIdentifier { } public AccountsManager(final Accounts accounts, @@ -1509,34 +1521,66 @@ public class AccountsManager extends RedisPubSubAdapter implemen } public CompletableFuture> waitForTransferArchive(final Account account, final Device device, final Duration timeout) { - final TimestampedDeviceIdentifier deviceIdentifier = - new TimestampedDeviceIdentifier(account.getIdentifier(IdentityType.ACI), - device.getId(), - Instant.ofEpochMilli(device.getCreated())); + final DeviceIdentifier timestampDeviceIdentifier = new TimestampDeviceIdentifier(account.getIdentifier(IdentityType.ACI), device.getId(), Instant.ofEpochMilli(device.getCreated())); + final String timestampTransferArchiveKey = getTimestampTransferArchiveKey(account.getIdentifier(IdentityType.ACI), device.getId(), Instant.ofEpochMilli(device.getCreated())); - return waitForPubSubKey(waitForTransferArchiveFuturesByDeviceIdentifier, - deviceIdentifier, - getTransferArchiveKey(account.getIdentifier(IdentityType.ACI), device.getId(), Instant.ofEpochMilli(device.getCreated())), + final DeviceIdentifier registrationIdDeviceIdentifier = new RegistrationIdDeviceIdentifier(account.getIdentifier(IdentityType.ACI), device.getId(), device.getRegistrationId(IdentityType.ACI)); + final String registrationIdTransferArchiveKey = getRegistrationIdTransferArchiveKey(account.getIdentifier(IdentityType.ACI), device.getId(), device.getRegistrationId(IdentityType.ACI)); + + final CompletableFuture> timestampFuture = waitForPubSubKey(waitForTransferArchiveFuturesByDeviceIdentifier, + timestampDeviceIdentifier, + timestampTransferArchiveKey, timeout, this::handleTransferArchiveAdded); + + final CompletableFuture> registrationIdFuture = waitForPubSubKey(waitForTransferArchiveFuturesByDeviceIdentifier, + registrationIdDeviceIdentifier, + registrationIdTransferArchiveKey, + timeout, + this::handleTransferArchiveAdded); + return firstSuccessfulTransferArchiveFuture(List.of(timestampFuture, registrationIdFuture)); + } + + @VisibleForTesting + static CompletableFuture> firstSuccessfulTransferArchiveFuture( + final List>> futures) { + final CompletableFuture> result = new CompletableFuture<>(); + final AtomicInteger remaining = new AtomicInteger(futures.size()); + + for (CompletableFuture> future : futures) { + future.whenComplete((value, _) -> { + if (value.isPresent()) { + result.complete(value); + } else if (remaining.decrementAndGet() == 0) { + result.complete(Optional.empty()); + } + }); + } + + return result; } public CompletableFuture recordTransferArchiveUpload(final Account account, final byte destinationDeviceId, - final Instant destinationDeviceCreationTimestamp, + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional destinationDeviceCreationTimestamp, + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional maybeRegistrationId, final TransferArchiveResult transferArchiveResult) { - - final String key = getTransferArchiveKey(account.getIdentifier(IdentityType.ACI), - destinationDeviceId, - destinationDeviceCreationTimestamp); - try { final String transferArchiveJson = SystemMapper.jsonMapper().writeValueAsString(transferArchiveResult); - return pubSubRedisClient.withConnection(connection -> - connection.async().set(key, transferArchiveJson, SetArgs.Builder.ex(RECENTLY_ADDED_TRANSFER_ARCHIVE_TTL))) - .thenRun(Util.NOOP) - .toCompletableFuture(); + return pubSubRedisClient.withConnection(connection -> { + final String key = destinationDeviceCreationTimestamp + .map(timestamp -> getTimestampTransferArchiveKey(account.getIdentifier(IdentityType.ACI), destinationDeviceId, timestamp)) + .orElseGet(() -> maybeRegistrationId + .map(registrationId -> getRegistrationIdTransferArchiveKey(account.getIdentifier(IdentityType.ACI), destinationDeviceId, registrationId)) + // We validate the request object so this should never happen + .orElseThrow(() -> new AssertionError("No creation timestamp or registration ID provided"))); + + return connection.async() + .set(key, transferArchiveJson, SetArgs.Builder.ex(RECENTLY_ADDED_TRANSFER_ARCHIVE_TTL)) + .thenRun(Util.NOOP) + .toCompletableFuture(); + }); } catch (final JsonProcessingException e) { // This should never happen for well-defined objects we control throw new UncheckedIOException(e); @@ -1552,15 +1596,27 @@ public class AccountsManager extends RedisPubSubAdapter implemen } } - private static String getTransferArchiveKey(final UUID accountIdentifier, + private static String getTimestampTransferArchiveKey(final UUID accountIdentifier, final byte destinationDeviceId, final Instant destinationDeviceCreationTimestamp) { + Metrics.counter(TIMESTAMP_BASED_TRANSFER_ARCHIVE_KEY_COUNTER_NAME).increment(); return TRANSFER_ARCHIVE_PREFIX + accountIdentifier.toString() + ":" + destinationDeviceId + ":" + destinationDeviceCreationTimestamp.toEpochMilli(); } + private static String getRegistrationIdTransferArchiveKey(final UUID accountIdentifier, + final byte destinationDeviceId, + final int registrationId) { + Metrics.counter(REGISTRATION_ID_BASED_TRANSFER_ARCHIVE_KEY_COUNTER_NAME).increment(); + + return TRANSFER_ARCHIVE_PREFIX + accountIdentifier.toString() + + ":" + destinationDeviceId + + ":" + TRANSFER_ARCHIVE_REGISTRATION_ID_PATTERN + + ":" + registrationId; + } + public CompletableFuture> waitForRestoreAccountRequest(final String token, final Duration timeout) { return waitForPubSubKey(waitForRestoreAccountRequestFuturesByToken, token, @@ -1648,23 +1704,36 @@ public class AccountsManager extends RedisPubSubAdapter implemen } else if (TRANSFER_ARCHIVE_KEYSPACE_PATTERN.equals(pattern) && "set".equalsIgnoreCase(message)) { // The `- 1` here compensates for the '*' in the pattern final String[] deviceIdentifierComponents = - channel.substring(TRANSFER_ARCHIVE_KEYSPACE_PATTERN.length() - 1).split(":", 3); + channel.substring(TRANSFER_ARCHIVE_KEYSPACE_PATTERN.length() - 1).split(":", 4); - if (deviceIdentifierComponents.length != 3) { - logger.error("Could not parse timestamped device identifier; unexpected component count"); + if (deviceIdentifierComponents.length != 3 && deviceIdentifierComponents.length != 4) { + logger.error("Could not parse device identifier; unexpected component count"); return; } + final DeviceIdentifier deviceIdentifier; + final String transferArchiveKey; try { - final TimestampedDeviceIdentifier deviceIdentifier; - final String transferArchiveKey; - { - final UUID accountIdentifier = UUID.fromString(deviceIdentifierComponents[0]); - final byte deviceId = Byte.parseByte(deviceIdentifierComponents[1]); + final UUID accountIdentifier = UUID.fromString(deviceIdentifierComponents[0]); + final byte deviceId = Byte.parseByte(deviceIdentifierComponents[1]); + + if (deviceIdentifierComponents.length == 3) { + // Parse the old transfer archive Redis key format final Instant deviceCreationTimestamp = Instant.ofEpochMilli(Long.parseLong(deviceIdentifierComponents[2])); - deviceIdentifier = new TimestampedDeviceIdentifier(accountIdentifier, deviceId, deviceCreationTimestamp); - transferArchiveKey = getTransferArchiveKey(accountIdentifier, deviceId, deviceCreationTimestamp); + deviceIdentifier = new TimestampDeviceIdentifier(accountIdentifier, deviceId, deviceCreationTimestamp); + transferArchiveKey = getTimestampTransferArchiveKey(accountIdentifier, deviceId, deviceCreationTimestamp); + } else { + final String maybeRegistrationIdPattern = deviceIdentifierComponents[2]; + if (!maybeRegistrationIdPattern.equals(TRANSFER_ARCHIVE_REGISTRATION_ID_PATTERN)) { + throw new IllegalArgumentException("Could not parse Redis key with pattern " + maybeRegistrationIdPattern); + } + final int registrationId = Integer.parseInt(deviceIdentifierComponents[3]); + if (!RegistrationIdValidator.validRegistrationId(registrationId)) { + throw new IllegalArgumentException("Invalid registration ID: " + registrationId); + } + deviceIdentifier = new RegistrationIdDeviceIdentifier(accountIdentifier, deviceId, registrationId); + transferArchiveKey = getRegistrationIdTransferArchiveKey(accountIdentifier, deviceId, registrationId); } Optional.ofNullable(waitForTransferArchiveFuturesByDeviceIdentifier.remove(deviceIdentifier)) @@ -1677,7 +1746,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen } })); } catch (final IllegalArgumentException e) { - logger.error("Could not parse timestamped device identifier", e); + logger.error("Could not parse device identifier", e); } } else if (RESTORE_ACCOUNT_REQUEST_KEYSPACE_PATTERN.equalsIgnoreCase(pattern) && "set".equalsIgnoreCase(message)) { // The `- 1` here compensates for the '*' in the pattern diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index 44c5c2a6f..35928e075 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -1082,31 +1082,39 @@ class DeviceControllerTest { } } - @Test - void recordTransferArchiveUploaded() { + @ParameterizedTest + @MethodSource + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + void recordTransferArchiveUploaded(final Optional deviceCreated, final Optional registrationId) { final byte deviceId = Device.PRIMARY_ID + 1; - final Instant deviceCreated = Instant.now().truncatedTo(ChronoUnit.MILLIS); final RemoteAttachment transferArchive = new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8))); when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null)); - when(accountsManager.recordTransferArchiveUpload(account, deviceId, deviceCreated, transferArchive)) + when(accountsManager.recordTransferArchiveUpload(account, deviceId, deviceCreated, registrationId, transferArchive)) .thenReturn(CompletableFuture.completedFuture(null)); try (final Response response = resources.getJerseyTest() .target("/v1/devices/transfer_archive") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new TransferArchiveUploadedRequest(deviceId, deviceCreated.toEpochMilli(), transferArchive), + .put(Entity.entity(new TransferArchiveUploadedRequest(deviceId, deviceCreated.map(Instant::toEpochMilli), registrationId, transferArchive), MediaType.APPLICATION_JSON_TYPE))) { assertEquals(204, response.getStatus()); verify(accountsManager) - .recordTransferArchiveUpload(account, deviceId, deviceCreated, transferArchive); + .recordTransferArchiveUpload(account, deviceId, deviceCreated, registrationId, transferArchive); } } + private static List recordTransferArchiveUploaded() { + return List.of( + Arguments.of(Optional.empty(), Optional.of(123)), + Arguments.of(Optional.of(Instant.now().truncatedTo(ChronoUnit.MILLIS)), Optional.empty()) + ); + } + @Test void recordTransferArchiveFailed() { final byte deviceId = Device.PRIMARY_ID + 1; @@ -1114,20 +1122,20 @@ class DeviceControllerTest { final RemoteAttachmentError transferFailure = new RemoteAttachmentError(RemoteAttachmentError.ErrorType.CONTINUE_WITHOUT_UPLOAD); when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null)); - when(accountsManager.recordTransferArchiveUpload(account, deviceId, deviceCreated, transferFailure)) + when(accountsManager.recordTransferArchiveUpload(account, deviceId, Optional.of(deviceCreated), Optional.empty(), transferFailure)) .thenReturn(CompletableFuture.completedFuture(null)); try (final Response response = resources.getJerseyTest() .target("/v1/devices/transfer_archive") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new TransferArchiveUploadedRequest(deviceId, deviceCreated.toEpochMilli(), transferFailure), + .put(Entity.entity(new TransferArchiveUploadedRequest(deviceId, Optional.of(deviceCreated.toEpochMilli()), Optional.empty(), transferFailure), MediaType.APPLICATION_JSON_TYPE))) { assertEquals(204, response.getStatus()); verify(accountsManager) - .recordTransferArchiveUpload(account, deviceId, deviceCreated, transferFailure); + .recordTransferArchiveUpload(account, deviceId, Optional.of(deviceCreated), Optional.empty(), transferFailure); } } @@ -1145,29 +1153,33 @@ class DeviceControllerTest { assertEquals(422, response.getStatus()); verify(accountsManager, never()) - .recordTransferArchiveUpload(any(), anyByte(), any(), any()); + .recordTransferArchiveUpload(any(), anyByte(), any(), any(), any()); } } @SuppressWarnings("DataFlowIssue") - private static List recordTransferArchiveUploadedBadRequest() { + private static List recordTransferArchiveUploadedBadRequest() { final RemoteAttachment validTransferArchive = new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("archive".getBytes(StandardCharsets.UTF_8))); return List.of( - // Invalid device ID - new TransferArchiveUploadedRequest((byte) -1, System.currentTimeMillis(), validTransferArchive), - - // Invalid "created at" timestamp - new TransferArchiveUploadedRequest(Device.PRIMARY_ID, -1, validTransferArchive), - - // Missing CDN number - new TransferArchiveUploadedRequest(Device.PRIMARY_ID, System.currentTimeMillis(), - new RemoteAttachment(null, Base64.getUrlEncoder().encodeToString("archive".getBytes(StandardCharsets.UTF_8)))), - - // Bad attachment key - new TransferArchiveUploadedRequest(Device.PRIMARY_ID, System.currentTimeMillis(), - new RemoteAttachment(3, "This is not a valid base64 string")) + Arguments.argumentSet("Invalid device ID", new TransferArchiveUploadedRequest((byte) -1, Optional.of(System.currentTimeMillis()), Optional.empty(), validTransferArchive)), + Arguments.argumentSet("Invalid \"created at\" timestamp", + new TransferArchiveUploadedRequest(Device.PRIMARY_ID, Optional.of((long) -1), Optional.empty(), validTransferArchive)), + Arguments.argumentSet("Invalid registration ID - negative", + new TransferArchiveUploadedRequest(Device.PRIMARY_ID, Optional.empty(), Optional.of(-1), validTransferArchive)), + Arguments.argumentSet("Invalid registration ID - too large", + new TransferArchiveUploadedRequest(Device.PRIMARY_ID, Optional.empty(), Optional.of(0x4000), validTransferArchive)), + Arguments.argumentSet("Exactly one of \"created at\" timestamp and registration ID must be present - neither provided", + new TransferArchiveUploadedRequest(Device.PRIMARY_ID, Optional.empty(), Optional.empty(), validTransferArchive)), + Arguments.argumentSet("Exactly one of \"created at\" timestamp and registration ID must be present - both provided", + new TransferArchiveUploadedRequest(Device.PRIMARY_ID, Optional.of(System.currentTimeMillis()), Optional.of(123), validTransferArchive)), + Arguments.argumentSet("Missing CDN number", + new TransferArchiveUploadedRequest(Device.PRIMARY_ID, Optional.of(System.currentTimeMillis()), Optional.empty(), + new RemoteAttachment(null, Base64.getUrlEncoder().encodeToString("archive".getBytes(StandardCharsets.UTF_8))))), + Arguments.argumentSet("Bad attachment key", + new TransferArchiveUploadedRequest(Device.PRIMARY_ID, Optional.of(System.currentTimeMillis()), Optional.empty(), + new RemoteAttachment(3, "This is not a valid base64 string"))) ); } @@ -1180,14 +1192,14 @@ class DeviceControllerTest { .target("/v1/devices/transfer_archive") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new TransferArchiveUploadedRequest(Device.PRIMARY_ID, System.currentTimeMillis(), + .put(Entity.entity(new TransferArchiveUploadedRequest(Device.PRIMARY_ID, Optional.of(System.currentTimeMillis()), Optional.empty(), new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)))), MediaType.APPLICATION_JSON_TYPE))) { assertEquals(429, response.getStatus()); verify(accountsManager, never()) - .recordTransferArchiveUpload(any(), anyByte(), any(), any()); + .recordTransferArchiveUpload(any(), anyByte(), any(), any(), any()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java index 18e5ee89d..6c6764be1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java @@ -11,6 +11,9 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.entities.RemoteAttachmentError; import org.whispersystems.textsecuregcm.entities.RestoreAccountRequest; @@ -28,6 +31,7 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.Base64; +import java.util.List; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; @@ -87,18 +91,22 @@ public class AccountsManagerDeviceTransferIntegrationTest { accountsManager.stop(); } - @Test - void waitForTransferArchive() { + @ParameterizedTest + @MethodSource + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + void waitForTransferArchive( + final Optional recordUploadDeviceCreated, + final Optional recordUploadRegistrationId) { final UUID accountIdentifier = UUID.randomUUID(); final byte deviceId = Device.PRIMARY_ID; - final long deviceCreated = System.currentTimeMillis(); final RemoteAttachment transferArchive = new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("transfer-archive".getBytes(StandardCharsets.UTF_8))); final Device device = mock(Device.class); when(device.getId()).thenReturn(deviceId); - when(device.getCreated()).thenReturn(deviceCreated); + when(device.getCreated()).thenReturn(recordUploadDeviceCreated.orElse(System.currentTimeMillis())); + when(device.getRegistrationId(IdentityType.ACI)).thenReturn(recordUploadRegistrationId.orElse(1)); final Account account = mock(Account.class); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); @@ -111,66 +119,106 @@ public class AccountsManagerDeviceTransferIntegrationTest { assertEquals(Optional.empty(), displacedFuture.join()); - accountsManager.recordTransferArchiveUpload(account, deviceId, Instant.ofEpochMilli(deviceCreated), transferArchive).join(); + accountsManager.recordTransferArchiveUpload(account, deviceId, recordUploadDeviceCreated.map(Instant::ofEpochMilli), recordUploadRegistrationId, transferArchive).join(); assertEquals(Optional.of(transferArchive), activeFuture.join()); } - @Test - void waitForTransferArchiveAlreadyAdded() { + private static List waitForTransferArchive() { + final long deviceCreated = System.currentTimeMillis(); + final int registrationId = 123; + + return List.of( + Arguments.of(Optional.empty(), Optional.of(registrationId)), + Arguments.of(Optional.of(deviceCreated), Optional.empty()) + ); + } + + @ParameterizedTest + @MethodSource + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + void waitForTransferArchiveAlreadyAdded( + final Optional recordUploadDeviceCreated, + final Optional recordUploadRegistrationId) { final UUID accountIdentifier = UUID.randomUUID(); final byte deviceId = Device.PRIMARY_ID; - final long deviceCreated = System.currentTimeMillis(); + final RemoteAttachment transferArchive = new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("transfer-archive".getBytes(StandardCharsets.UTF_8))); final Device device = mock(Device.class); when(device.getId()).thenReturn(deviceId); - when(device.getCreated()).thenReturn(deviceCreated); + when(device.getCreated()).thenReturn(recordUploadDeviceCreated.orElse(System.currentTimeMillis())); + when(device.getRegistrationId(IdentityType.ACI)).thenReturn(recordUploadRegistrationId.orElse(1)); final Account account = mock(Account.class); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); - accountsManager.recordTransferArchiveUpload(account, deviceId, Instant.ofEpochMilli(deviceCreated), transferArchive).join(); + accountsManager.recordTransferArchiveUpload(account, deviceId, recordUploadDeviceCreated.map(Instant::ofEpochMilli), recordUploadRegistrationId, transferArchive).join(); assertEquals(Optional.of(transferArchive), accountsManager.waitForTransferArchive(account, device, Duration.ofSeconds(5)).join()); } - @Test - void waitForErrorTransferArchive() { + private static List waitForTransferArchiveAlreadyAdded() { + final long deviceCreated = System.currentTimeMillis(); + final int registrationId = 123; + + return List.of( + Arguments.of(Optional.empty(), Optional.of(registrationId)), + Arguments.of(Optional.of(deviceCreated), Optional.empty()) + ); + } + + @ParameterizedTest + @MethodSource + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + void waitForErrorTransferArchive( + final Optional recordUploadDeviceCreated, + final Optional recordUploadRegistrationId) { final UUID accountIdentifier = UUID.randomUUID(); final byte deviceId = Device.PRIMARY_ID; - final long deviceCreated = System.currentTimeMillis(); final RemoteAttachmentError transferArchiveError = new RemoteAttachmentError(RemoteAttachmentError.ErrorType.CONTINUE_WITHOUT_UPLOAD); final Device device = mock(Device.class); when(device.getId()).thenReturn(deviceId); - when(device.getCreated()).thenReturn(deviceCreated); + when(device.getCreated()).thenReturn(recordUploadDeviceCreated.orElse(System.currentTimeMillis())); + when(device.getRegistrationId(IdentityType.ACI)).thenReturn(recordUploadRegistrationId.orElse(1)); final Account account = mock(Account.class); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); - accountsManager - .recordTransferArchiveUpload(account, deviceId, Instant.ofEpochMilli(deviceCreated), transferArchiveError) - .join(); + accountsManager.recordTransferArchiveUpload(account, deviceId, recordUploadDeviceCreated.map(Instant::ofEpochMilli), + recordUploadRegistrationId, transferArchiveError).join(); assertEquals(Optional.of(transferArchiveError), accountsManager.waitForTransferArchive(account, device, Duration.ofSeconds(5)).join()); } - @Test - void waitForTransferArchiveTimeout() { - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = Device.PRIMARY_ID; + private static List waitForErrorTransferArchive() { final long deviceCreated = System.currentTimeMillis(); + final int registrationId = 123; + + return List.of( + Arguments.of(Optional.empty(), Optional.of(registrationId)), + Arguments.of(Optional.of(deviceCreated), Optional.empty()) + ); + } + + @ParameterizedTest + @MethodSource + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + void waitForTransferArchiveTimeout( + final Optional recordUploadDeviceCreated, + final Optional recordUploadRegistrationId) { + final UUID accountIdentifier = UUID.randomUUID(); final Device device = mock(Device.class); - when(device.getId()).thenReturn(deviceId); - when(device.getCreated()).thenReturn(deviceCreated); + when(device.getCreated()).thenReturn(recordUploadDeviceCreated.orElse(System.currentTimeMillis())); + when(device.getRegistrationId(IdentityType.ACI)).thenReturn(recordUploadRegistrationId.orElse(1)); final Account account = mock(Account.class); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); @@ -179,6 +227,16 @@ public class AccountsManagerDeviceTransferIntegrationTest { accountsManager.waitForTransferArchive(account, device, Duration.ofMillis(1)).join()); } + private static List waitForTransferArchiveTimeout() { + final long deviceCreated = System.currentTimeMillis(); + final int registrationId = 123; + + return List.of( + Arguments.of(Optional.empty(), Optional.of(registrationId)), + Arguments.of(Optional.of(deviceCreated), Optional.empty()) + ); + } + @Test void waitForRestoreAccountRequest() { final String token = RandomStringUtils.secure().nextAlphanumeric(16); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index be7f8ef9f..9a49318e7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -57,6 +57,7 @@ import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.function.Supplier; import java.util.stream.Stream; @@ -82,6 +83,8 @@ import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; +import org.whispersystems.textsecuregcm.entities.RemoteAttachment; +import org.whispersystems.textsecuregcm.entities.TransferArchiveResult; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; @@ -1498,6 +1501,61 @@ class AccountsManagerTest { } } + @Test + void testFirstSuccessfulTransferArchiveCompletableFutureOneTimeout() { + // First future times out, second one completes successfully + final RemoteAttachment transferArchive = new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8))); + + final CompletableFuture> timeoutFuture = new CompletableFuture<>(); + timeoutFuture.completeOnTimeout(Optional.empty(), 50, TimeUnit.MILLISECONDS); + + final CompletableFuture> successfulFuture = new CompletableFuture<>(); + + final CompletableFuture> result = + AccountsManager.firstSuccessfulTransferArchiveFuture(List.of(timeoutFuture, successfulFuture)); + + CompletableFuture.delayedExecutor(100, TimeUnit.MILLISECONDS) + .execute(() -> successfulFuture.complete(Optional.of(transferArchive))); + + final Optional maybeTransferArchive = result.join(); + assertTrue(maybeTransferArchive.isPresent()); + assertEquals(transferArchive, maybeTransferArchive.get()); + } + + @Test + void testFirstSuccessfulTransferArchiveCompletableFutureBothTimeout() { + // Both futures time out + final CompletableFuture> firstTimeoutFuture = new CompletableFuture<>(); + firstTimeoutFuture.completeOnTimeout(Optional.empty(), 10, TimeUnit.MILLISECONDS); + + final CompletableFuture> secondTimeoutFuture = new CompletableFuture<>(); + secondTimeoutFuture.completeOnTimeout(Optional.empty(), 10, TimeUnit.MILLISECONDS); + + final CompletableFuture> result = + AccountsManager.firstSuccessfulTransferArchiveFuture(List.of(firstTimeoutFuture, secondTimeoutFuture)); + + assertTrue(result.join().isEmpty()); + } + + @Test + void testFirstSuccessfulTransferArchiveCompletableFuture() { + // First future completes successfully, second one times out + final RemoteAttachment transferArchive = new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8))); + + final CompletableFuture> successfulFuture = new CompletableFuture<>(); + + final CompletableFuture> timeoutFuture = new CompletableFuture<>(); + timeoutFuture.completeOnTimeout(Optional.empty(), 50, TimeUnit.MILLISECONDS); + + final CompletableFuture> result = + AccountsManager.firstSuccessfulTransferArchiveFuture(List.of(successfulFuture, timeoutFuture)); + successfulFuture.complete(Optional.of(transferArchive)); + + final Optional maybeTransferArchive = result.join(); + assertTrue(maybeTransferArchive.isPresent()); + assertEquals(transferArchive, maybeTransferArchive.get()); + } + private static List validateCompleteDeviceList() { final byte deviceId = Device.PRIMARY_ID; final byte extraDeviceId = deviceId + 1;