diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupAuthManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupAuthManager.java index d7972402b..77270ad17 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupAuthManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupAuthManager.java @@ -54,7 +54,6 @@ public class BackupAuthManager { final static Duration MAX_REDEMPTION_DURATION = Duration.ofDays(7); - final static String BACKUP_EXPERIMENT_NAME = "backup"; final static String BACKUP_MEDIA_EXPERIMENT_NAME = "backupMedia"; private final ExperimentEnrollmentManager experimentEnrollmentManager; @@ -99,9 +98,6 @@ public class BackupAuthManager { final Device device, final BackupAuthCredentialRequest messagesBackupCredentialRequest, final BackupAuthCredentialRequest mediaBackupCredentialRequest) { - if (configuredBackupLevel(account).isEmpty()) { - throw Status.PERMISSION_DENIED.withDescription("Backups not allowed on account").asRuntimeException(); - } if (!device.isPrimary()) { throw Status.PERMISSION_DENIED.withDescription("Only primary device can set backup-id").asRuntimeException(); } @@ -172,10 +168,6 @@ public class BackupAuthManager { }).thenCompose(updated -> getBackupAuthCredentials(updated, credentialType, redemptionStart, redemptionEnd)); } - // If this account isn't allowed some level of backup access via configuration, don't continue - final BackupLevel configuredBackupLevel = configuredBackupLevel(account).orElseThrow(() -> - Status.PERMISSION_DENIED.withDescription("Backups not allowed on account").asRuntimeException()); - final Instant startOfDay = clock.instant().truncatedTo(ChronoUnit.DAYS); if (redemptionStart.isAfter(redemptionEnd) || redemptionStart.isBefore(startOfDay) || @@ -191,6 +183,8 @@ public class BackupAuthManager { .orElseThrow(() -> Status.NOT_FOUND.withDescription("No blinded backup-id has been added to the account").asRuntimeException()); try { + final BackupLevel defaultBackupLevel = configuredBackupLevel(account); + // create a credential for every day in the requested period final BackupAuthCredentialRequest credentialReq = new BackupAuthCredentialRequest(committedBytes); return CompletableFuture.completedFuture(Stream @@ -198,7 +192,7 @@ public class BackupAuthManager { .map(redemptionTime -> { // Check if the account has a voucher that's good for a certain receiptLevel at redemption time, otherwise // use the default receipt level - final BackupLevel backupLevel = storedBackupLevel(account, redemptionTime).orElse(configuredBackupLevel); + final BackupLevel backupLevel = storedBackupLevel(account, redemptionTime).orElse(defaultBackupLevel); return new Credential( credentialReq.issueCredential(redemptionTime, backupLevel, credentialType, serverSecretParams), redemptionTime); @@ -328,20 +322,12 @@ public class BackupAuthManager { * Get the backup receipt level that should be used by default for this account determined via configuration. * * @param account the account to check - * @return If present, the default receipt level that should be used for the account if the account does not have a - * BackupVoucher. Empty if the account should never have backup access + * @return The default receipt level that should be used for the account if the account does not have a + * BackupVoucher. */ - private Optional configuredBackupLevel(final Account account) { - if (inExperiment(BACKUP_MEDIA_EXPERIMENT_NAME, account)) { - return Optional.of(BackupLevel.PAID); - } - if (inExperiment(BACKUP_EXPERIMENT_NAME, account)) { - return Optional.of(BackupLevel.FREE); - } - return Optional.empty(); - } - - private boolean inExperiment(final String experimentName, final Account account) { - return this.experimentEnrollmentManager.isEnrolled(account.getUuid(), experimentName); + private BackupLevel configuredBackupLevel(final Account account) { + return this.experimentEnrollmentManager.isEnrolled(account.getUuid(), BACKUP_MEDIA_EXPERIMENT_NAME) + ? BackupLevel.PAID + : BackupLevel.FREE; } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java index c4b44a9d4..91c27ddf7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java @@ -29,7 +29,6 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; import java.util.stream.Stream; -import javax.annotation.Nullable; import org.assertj.core.api.Assertions; import org.assertj.core.api.ThrowableAssert; import org.junit.jupiter.api.BeforeEach; @@ -38,7 +37,6 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; -import org.junit.jupiter.params.provider.NullSource; import org.junit.jupiter.params.provider.ValueSource; import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.mockito.ArgumentCaptor; @@ -57,6 +55,7 @@ import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialResponse; import org.signal.libsignal.zkgroup.receipts.ReceiptSerial; import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.storage.Account; @@ -86,13 +85,16 @@ public class BackupAuthManagerTest { reset(redeemedReceiptsManager); } - BackupAuthManager create(@Nullable BackupLevel backupLevel) { - return create(backupLevel, rateLimiter(aci, false, false)); + BackupAuthManager create() { + return create(BackupLevel.FREE, rateLimiter(aci, false, false)); } - BackupAuthManager create(@Nullable BackupLevel backupLevel, RateLimiters rateLimiters) { + BackupAuthManager create(BackupLevel defaultBackupLevel, RateLimiters rateLimiters) { return new BackupAuthManager( - ExperimentHelper.withEnrollment(experimentName(backupLevel), aci), + switch (defaultBackupLevel) { + case FREE -> mock(ExperimentEnrollmentManager.class); + case PAID -> ExperimentHelper.withEnrollment(BackupAuthManager.BACKUP_MEDIA_EXPERIMENT_NAME, aci); + }, rateLimiters, accountsManager, new ServerZkReceiptOperations(receiptParams), @@ -103,7 +105,7 @@ public class BackupAuthManagerTest { @Test void commitBackupId() { - final BackupAuthManager authManager = create(BackupLevel.FREE); + final BackupAuthManager authManager = create(); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); @@ -122,16 +124,15 @@ public class BackupAuthManagerTest { authManager.commitBackupId(account, primaryDevice(), messagesCredentialRequest, mediaCredentialRequest).join(); - verify(account).setBackupCredentialRequests(messagesCredentialRequest.serialize(), mediaCredentialRequest.serialize()); + verify(account).setBackupCredentialRequests(messagesCredentialRequest.serialize(), + mediaCredentialRequest.serialize()); } @ParameterizedTest @EnumSource - @NullSource - void commitRequiresBackupLevel(final BackupLevel backupLevel) { - final BackupAuthManager authManager = create(backupLevel); - final Account account = mock(Account.class); - when(account.getUuid()).thenReturn(aci); + void commitOnAnyBackupLevel(final BackupLevel backupLevel) { + final BackupAuthManager authManager = create(); + final Account account = new MockAccountBuilder().backupLevel(backupLevel).build(); when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); final ThrowableAssert.ThrowingCallable commit = () -> @@ -139,21 +140,13 @@ public class BackupAuthManagerTest { primaryDevice(), backupAuthTestUtil.getRequest(messagesBackupKey, aci), backupAuthTestUtil.getRequest(mediaBackupKey, aci)).join(); - if (backupLevel == null) { - assertThatExceptionOfType(StatusRuntimeException.class) - .isThrownBy(commit) - .extracting(ex -> ex.getStatus().getCode()) - .isEqualTo(Status.Code.PERMISSION_DENIED); - } else { - Assertions.assertThatNoException().isThrownBy(commit); - } + Assertions.assertThatNoException().isThrownBy(commit); } @Test void commitRequiresPrimary() { - final BackupAuthManager authManager = create(BackupLevel.FREE); - final Account account = mock(Account.class); - when(account.getUuid()).thenReturn(aci); + final BackupAuthManager authManager = create(); + final Account account = new MockAccountBuilder().build(); when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); final ThrowableAssert.ThrowingCallable commit = () -> @@ -167,18 +160,47 @@ public class BackupAuthManagerTest { .isEqualTo(Status.Code.PERMISSION_DENIED); } + @CartesianTest + void paidTierCredentialViaConfiguration(@CartesianTest.Enum final BackupCredentialType credentialType) + throws VerificationFailedException { + final BackupAuthManager authManager = create(BackupLevel.PAID, rateLimiter(aci, false, false)); + + final byte[] backupKey = switch (credentialType) { + case MESSAGES -> messagesBackupKey; + case MEDIA -> mediaBackupKey; + }; + + // Account does not have PAID tier set + final Account account = new MockAccountBuilder() + .messagesCredential(backupAuthTestUtil.getRequest(messagesBackupKey, aci)) + .mediaCredential(backupAuthTestUtil.getRequest(mediaBackupKey, aci)) + .build(); + + final BackupAuthCredentialRequestContext requestContext = + BackupAuthCredentialRequestContext.create(backupKey, aci); + + final Instant start = clock.instant().truncatedTo(ChronoUnit.DAYS); + final List creds = authManager.getBackupAuthCredentials(account, + credentialType, start, start.plus(Duration.ofDays(1))).join(); + + assertThat(creds).hasSize(2); + assertThat(requestContext + .receiveResponse(creds.getFirst().credential(), start, backupAuthTestUtil.params.getPublicParams()) + .getBackupLevel()) + .isEqualTo(BackupLevel.PAID); + } + @CartesianTest void getBackupAuthCredentials(@CartesianTest.Enum final BackupLevel backupLevel, @CartesianTest.Enum final BackupCredentialType credentialType) { - final BackupAuthManager authManager = create(backupLevel); + final BackupAuthManager authManager = create(); - final Account account = mock(Account.class); - when(account.getUuid()).thenReturn(aci); - when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) - .thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize())); - when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)) - .thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize())); + final Account account = new MockAccountBuilder() + .backupLevel(backupLevel) + .messagesCredential(backupAuthTestUtil.getRequest(messagesBackupKey, aci)) + .mediaCredential(backupAuthTestUtil.getRequest(mediaBackupKey, aci)) + .build(); assertThat(authManager.getBackupAuthCredentials(account, credentialType, @@ -189,15 +211,10 @@ public class BackupAuthManagerTest { @ParameterizedTest @EnumSource - void getBackupAuthCredentialsNoBackupLevel(final BackupCredentialType credentialType) { - final BackupAuthManager authManager = create(null); + void getBackupAuthCredentialsNoCommittedId(final BackupCredentialType credentialType) { + final BackupAuthManager authManager = create(); - final Account account = mock(Account.class); - when(account.getUuid()).thenReturn(aci); - when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) - .thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize())); - when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)) - .thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize())); + final Account account = new MockAccountBuilder().build(); assertThatExceptionOfType(StatusRuntimeException.class) .isThrownBy(() -> authManager.getBackupAuthCredentials(account, @@ -205,13 +222,13 @@ public class BackupAuthManagerTest { clock.instant().truncatedTo(ChronoUnit.DAYS), clock.instant().plus(Duration.ofDays(1)).truncatedTo(ChronoUnit.DAYS)).join()) .extracting(ex -> ex.getStatus().getCode()) - .isEqualTo(Status.Code.PERMISSION_DENIED); + .isEqualTo(Status.Code.NOT_FOUND); } @CartesianTest void getReceiptCredentials(@CartesianTest.Enum final BackupLevel backupLevel, @CartesianTest.Enum final BackupCredentialType credentialType) throws VerificationFailedException { - final BackupAuthManager authManager = create(backupLevel); + final BackupAuthManager authManager = create(); final byte[] backupKey = switch (credentialType) { case MESSAGES -> messagesBackupKey; @@ -221,12 +238,11 @@ public class BackupAuthManagerTest { final BackupAuthCredentialRequestContext requestContext = BackupAuthCredentialRequestContext.create(backupKey, aci); - final Account account = mock(Account.class); - when(account.getUuid()).thenReturn(aci); - when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) - .thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize())); - when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)) - .thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize())); + final Account account = new MockAccountBuilder() + .backupLevel(backupLevel) + .mediaCredential(backupAuthTestUtil.getRequest(mediaBackupKey, aci)) + .messagesCredential(backupAuthTestUtil.getRequest(messagesBackupKey, aci)) + .build(); final Instant start = clock.instant().truncatedTo(ChronoUnit.DAYS); final List creds = authManager.getBackupAuthCredentials(account, @@ -268,14 +284,12 @@ public class BackupAuthManagerTest { @MethodSource void invalidCredentialTimeWindows(final Instant requestRedemptionStart, final Instant requestRedemptionEnd, final Instant now) { - final BackupAuthManager authManager = create(BackupLevel.FREE); + final BackupAuthManager authManager = create(); - final Account account = mock(Account.class); - when(account.getUuid()).thenReturn(aci); - when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) - .thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize())); - when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)) - .thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize())); + final Account account = new MockAccountBuilder() + .messagesCredential(backupAuthTestUtil.getRequest(messagesBackupKey, aci)) + .mediaCredential(backupAuthTestUtil.getRequest(mediaBackupKey, aci)) + .build(); clock.pin(now); assertThatExceptionOfType(StatusRuntimeException.class) @@ -292,15 +306,13 @@ public class BackupAuthManagerTest { final Instant day4 = Instant.EPOCH.plus(Duration.ofDays(4)); final Instant dayMax = day0.plus(BackupAuthManager.MAX_REDEMPTION_DURATION); - final BackupAuthManager authManager = create(BackupLevel.FREE); + final BackupAuthManager authManager = create(); - final Account account = mock(Account.class); - when(account.getUuid()).thenReturn(aci); - when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) - .thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize())); - when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)) - .thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize())); - when(account.getBackupVoucher()).thenReturn(new Account.BackupVoucher(201, day4)); + final Account account = new MockAccountBuilder() + .messagesCredential(backupAuthTestUtil.getRequest(messagesBackupKey, aci)) + .mediaCredential(backupAuthTestUtil.getRequest(mediaBackupKey, aci)) + .backupVoucher(new Account.BackupVoucher(201, day4)) + .build(); final List creds = authManager.getBackupAuthCredentials(account, BackupCredentialType.MESSAGES, day0, dayMax).join(); Instant redemptionTime = day0; @@ -325,19 +337,19 @@ public class BackupAuthManagerTest { final Instant day2 = Instant.EPOCH.plus(Duration.ofDays(2)); final Instant day3 = Instant.EPOCH.plus(Duration.ofDays(3)); - final BackupAuthManager authManager = create(BackupLevel.FREE); - final Account account = mock(Account.class); - when(account.getUuid()).thenReturn(aci); - when(account.getBackupVoucher()).thenReturn(new Account.BackupVoucher(3, day1)); + final BackupAuthManager authManager = create(); + final Account account = new MockAccountBuilder() + .messagesCredential(backupAuthTestUtil.getRequest(messagesBackupKey, aci)) + .mediaCredential(backupAuthTestUtil.getRequest(mediaBackupKey, aci)) + .backupVoucher(new Account.BackupVoucher(3, day1)) + .build(); - final Account updated = mock(Account.class); - when(updated.getUuid()).thenReturn(aci); - when(updated.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) - .thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize())); - when(updated.getBackupCredentialRequest(BackupCredentialType.MEDIA)) - .thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize())); + final Account updated = new MockAccountBuilder() + .messagesCredential(backupAuthTestUtil.getRequest(messagesBackupKey, aci)) + .mediaCredential(backupAuthTestUtil.getRequest(mediaBackupKey, aci)) + .backupVoucher(null) + .build(); - when(updated.getBackupVoucher()).thenReturn(null); when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(updated)); clock.pin(day2.plus(Duration.ofSeconds(1))); @@ -365,11 +377,10 @@ public class BackupAuthManagerTest { @Test void redeemReceipt() throws InvalidInputException, VerificationFailedException { final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); - final BackupAuthManager authManager = create(BackupLevel.FREE); - final Account account = mock(Account.class); - when(account.getUuid()).thenReturn(aci); - when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)).thenReturn(Optional.of(new byte[0])); - + final BackupAuthManager authManager = create(); + final Account account = new MockAccountBuilder() + .mediaCredential(Optional.of(new byte[0])) + .build(); clock.pin(Instant.EPOCH.plus(Duration.ofDays(1))); when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); when(redeemedReceiptsManager.put(any(), eq(expirationTime.getEpochSecond()), eq(201L), eq(aci))) @@ -381,10 +392,8 @@ public class BackupAuthManagerTest { @Test void redeemReceiptNoBackupRequest() { final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); - final BackupAuthManager authManager = create(BackupLevel.FREE); - final Account account = mock(Account.class); - when(account.getUuid()).thenReturn(aci); - when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)).thenReturn(Optional.empty()); + final BackupAuthManager authManager = create(); + final Account account = new MockAccountBuilder().mediaCredential(Optional.empty()).build(); clock.pin(Instant.EPOCH.plus(Duration.ofDays(1))); when(redeemedReceiptsManager.put(any(), eq(expirationTime.getEpochSecond()), eq(201L), eq(aci))) @@ -401,13 +410,12 @@ public class BackupAuthManagerTest { final Instant newExpirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); final Instant existingExpirationTime = Instant.EPOCH.plus(Duration.ofDays(1)).plus(Duration.ofSeconds(1)); - final BackupAuthManager authManager = create(BackupLevel.FREE); - final Account account = mock(Account.class); - when(account.getUuid()).thenReturn(aci); - when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)).thenReturn(Optional.of(new byte[0])); - - // The account has an existing voucher with a later expiration date - when(account.getBackupVoucher()).thenReturn(new Account.BackupVoucher(201, existingExpirationTime)); + final BackupAuthManager authManager = create(); + final Account account = new MockAccountBuilder() + .mediaCredential(Optional.of(new byte[0])) + // The account has an existing voucher with a later expiration date + .backupVoucher(new Account.BackupVoucher(201, existingExpirationTime)) + .build(); clock.pin(Instant.EPOCH.plus(Duration.ofDays(1))); when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); @@ -427,7 +435,7 @@ public class BackupAuthManagerTest { void redeemExpiredReceipt() { final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); clock.pin(expirationTime.plus(Duration.ofSeconds(1))); - final BackupAuthManager authManager = create(BackupLevel.FREE); + final BackupAuthManager authManager = create(); assertThatExceptionOfType(StatusRuntimeException.class) .isThrownBy(() -> authManager.redeemReceipt(mock(Account.class), receiptPresentation(3, expirationTime)).join()) .extracting(ex -> ex.getStatus().getCode()) @@ -441,7 +449,7 @@ public class BackupAuthManagerTest { void redeemInvalidLevel(long level) { final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); clock.pin(expirationTime.plus(Duration.ofSeconds(1))); - final BackupAuthManager authManager = create(BackupLevel.FREE); + final BackupAuthManager authManager = create(); assertThatExceptionOfType(StatusRuntimeException.class) .isThrownBy(() -> authManager.redeemReceipt(mock(Account.class), receiptPresentation(level, expirationTime)).join()) @@ -453,7 +461,7 @@ public class BackupAuthManagerTest { @Test void redeemInvalidPresentation() throws InvalidInputException, VerificationFailedException { - final BackupAuthManager authManager = create(BackupLevel.FREE); + final BackupAuthManager authManager = create(); final ReceiptCredentialPresentation invalid = receiptPresentation(ServerSecretParams.generate(), 3L, Instant.EPOCH); assertThatExceptionOfType(StatusRuntimeException.class) .isThrownBy(() -> authManager.redeemReceipt(mock(Account.class), invalid).join()) @@ -466,10 +474,10 @@ public class BackupAuthManagerTest { @Test void receiptAlreadyRedeemed() throws InvalidInputException, VerificationFailedException { final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); - final BackupAuthManager authManager = create(BackupLevel.FREE); - final Account account = mock(Account.class); - when(account.getUuid()).thenReturn(aci); - when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)).thenReturn(Optional.of(new byte[0])); + final BackupAuthManager authManager = create(); + final Account account = new MockAccountBuilder() + .mediaCredential(Optional.of(new byte[0])) + .build(); clock.pin(Instant.EPOCH.plus(Duration.ofDays(1))); when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); @@ -513,7 +521,12 @@ public class BackupAuthManagerTest { final BackupAuthManager authManager = create(BackupLevel.FREE, rateLimiter(aci, rateLimitBackupId, false)); final BackupAuthCredentialRequest storedMessagesCredential = backupAuthTestUtil.getRequest(messagesBackupKey, aci); final BackupAuthCredentialRequest storedMediaCredential = backupAuthTestUtil.getRequest(mediaBackupKey, aci); - final Account account = mockAccount(storedMessagesCredential, storedMediaCredential, null); + final Account account = new MockAccountBuilder() + .mediaCredential(storedMediaCredential) + .messagesCredential(storedMessagesCredential) + .backupVoucher(null) + .build(); + when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); final BackupAuthCredentialRequest newMessagesCredential = changeMessage ? backupAuthTestUtil.getRequest(TestRandomUtil.nextBytes(32), aci) @@ -546,7 +559,12 @@ public class BackupAuthManagerTest { final Account.BackupVoucher backupVoucher = new Account.BackupVoucher(1, Instant.ofEpochSecond(100)); clock.pin(paid ? Instant.ofEpochSecond(99) : Instant.ofEpochSecond(101)); - final Account account = mockAccount(storedMessagesCredential, storedMediaCredential, backupVoucher); + final Account account = new MockAccountBuilder() + .mediaCredential(storedMediaCredential) + .messagesCredential(storedMessagesCredential) + .backupVoucher(backupVoucher) + .build(); + when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); final BackupAuthCredentialRequest newMessagesCredential = changeMessage ? backupAuthTestUtil.getRequest(TestRandomUtil.nextBytes(32), aci) @@ -566,22 +584,6 @@ public class BackupAuthManagerTest { } } - private Account mockAccount(final BackupAuthCredentialRequest storedMessagesCredential, final BackupAuthCredentialRequest storedMediaCredential, Account.BackupVoucher backupVoucher) { - final Account account = mock(Account.class); - when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); - if (storedMessagesCredential != null) { - when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) - .thenReturn(Optional.of(storedMessagesCredential.serialize())); - } - if (storedMediaCredential != null) { - when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)) - .thenReturn(Optional.of(storedMediaCredential.serialize())); - } - when(account.getUuid()).thenReturn(aci); - when(account.getBackupVoucher()).thenReturn(backupVoucher); - return account; - } - private Device primaryDevice() { final Device device = mock(Device.class); when(device.isPrimary()).thenReturn(true); @@ -594,14 +596,48 @@ public class BackupAuthManagerTest { return device; } - private static String experimentName(@Nullable BackupLevel backupLevel) { - return switch (backupLevel) { - case FREE -> BackupAuthManager.BACKUP_EXPERIMENT_NAME; - case PAID -> BackupAuthManager.BACKUP_MEDIA_EXPERIMENT_NAME; - case null -> "fake_experiment"; - }; + private class MockAccountBuilder { + + private Account account = mock(Account.class); + + MockAccountBuilder() { + when(account.getUuid()).thenReturn(aci); + } + + MockAccountBuilder backupLevel(BackupLevel backupLevel) { + if (backupLevel == BackupLevel.PAID) { + return backupVoucher(new Account.BackupVoucher(201L, clock.instant().plus(Duration.ofDays(8)))); + } + return this; + } + + MockAccountBuilder backupVoucher(Account.BackupVoucher backupVoucher) { + when(account.getBackupVoucher()).thenReturn(backupVoucher); + return this; + } + + MockAccountBuilder mediaCredential(final BackupAuthCredentialRequest storedMediaCredential) { + return mediaCredential(Optional.of(storedMediaCredential.serialize())); + } + + MockAccountBuilder mediaCredential(final Optional serializedMediaCredential) { + when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)) + .thenReturn(serializedMediaCredential); + return this; + } + + MockAccountBuilder messagesCredential(final BackupAuthCredentialRequest storedMessagesCredential) { + when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) + .thenReturn(Optional.of(storedMessagesCredential.serialize())); + return this; + } + + Account build() { + return account; + } } + private static RateLimiters rateLimiter(final UUID aci, boolean rateLimitBackupId, boolean rateLimitPaidMediaBackupId) { final RateLimiters limiters = mock(RateLimiters.class); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthTestUtil.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthTestUtil.java index d0029bc83..070cd05a6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthTestUtil.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthTestUtil.java @@ -21,8 +21,8 @@ import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequest; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequestContext; import org.signal.libsignal.zkgroup.backups.BackupCredentialType; import org.signal.libsignal.zkgroup.backups.BackupLevel; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.tests.util.ExperimentHelper; public class BackupAuthTestUtil { @@ -64,15 +64,15 @@ public class BackupAuthTestUtil { final Instant redemptionEnd) { final UUID aci = UUID.randomUUID(); - final String experimentName = switch (backupLevel) { - case FREE -> BackupAuthManager.BACKUP_EXPERIMENT_NAME; - case PAID -> BackupAuthManager.BACKUP_MEDIA_EXPERIMENT_NAME; - }; final BackupAuthManager issuer = new BackupAuthManager( - ExperimentHelper.withEnrollment(experimentName, aci), null, null, null, null, params, clock); + mock(ExperimentEnrollmentManager.class), null, null, null, null, params, clock); Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); when(account.getBackupCredentialRequest(credentialType)).thenReturn(Optional.of(request.serialize())); + when(account.getBackupVoucher()).thenReturn(switch (backupLevel) { + case FREE -> null; + case PAID -> new Account.BackupVoucher(201L, redemptionEnd.plus(1, ChronoUnit.SECONDS)); + }); return issuer.getBackupAuthCredentials(account, credentialType, redemptionStart, redemptionEnd).join(); } }