Forgive some clock skew when requesting ZK credentials

This commit is contained in:
ravi-signal
2025-10-01 13:03:27 -05:00
committed by GitHub
parent 70ac4ad139
commit 9384813752
10 changed files with 273 additions and 117 deletions

View File

@@ -0,0 +1,95 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatNoException;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.StreamSupport;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.util.TestClock;
class RedemptionRangeTest {
static List<Arguments> invalidCredentialTimeWindows() {
final Duration max = RedemptionRange.MAX_REDEMPTION_DURATION;
final Instant day0 = Instant.EPOCH;
final Instant day1 = Instant.EPOCH.plus(Duration.ofDays(1));
final Instant day2 = Instant.EPOCH.plus(Duration.ofDays(2));
return List.of(
Arguments.argumentSet("non-truncated start", Instant.ofEpochSecond(100), day0.plus(max),
Instant.ofEpochSecond(100)),
Arguments.argumentSet("non-truncated end", day0, Instant.ofEpochSecond(1).plus(max),
Instant.ofEpochSecond(100)),
Arguments.argumentSet("start too old", day0, day0.plus(max), day2),
Arguments.argumentSet("end too far in the future", day2, day2.plus(max), day0),
Arguments.argumentSet("end before start", day1, day0, day1),
Arguments.argumentSet("window too big", day0, day0.plus(max).plus(Duration.ofDays(1)), day1)
);
}
@ParameterizedTest
@MethodSource
void invalidCredentialTimeWindows(final Instant requestRedemptionStart, final Instant requestRedemptionEnd,
final Instant now) {
final Clock clock = TestClock.pinned(now);
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> RedemptionRange.inclusive(clock, requestRedemptionStart, requestRedemptionEnd));
}
@Test
void allowUpToMax() {
final Instant now = Instant.EPOCH.plus(Duration.ofDays(1)).plus(Duration.ofSeconds(100));
final Instant today = now.truncatedTo(ChronoUnit.DAYS);
final Clock clock = TestClock.pinned(now);
for (Duration d = Duration.ofDays(0);
d.compareTo(RedemptionRange.MAX_REDEMPTION_DURATION) <= 0;
d = d.plus(Duration.ofDays(1))) {
final Duration fd = d;
assertThatNoException().isThrownBy(() -> RedemptionRange.inclusive(clock, today, today.plus(fd)));
}
}
@Test
void allowBackwardsSkew() {
final Instant now = Instant.EPOCH.plus(Duration.ofDays(1)).plus(Duration.ofSeconds(100));
final Instant yesterday = now.minus(Duration.ofDays(1)).truncatedTo(ChronoUnit.DAYS);
final Clock clock = TestClock.pinned(now);
assertThatNoException().isThrownBy(() ->
RedemptionRange.inclusive(clock, yesterday, yesterday.plus(RedemptionRange.MAX_REDEMPTION_DURATION)));
}
@Test
void allowForwardsSkew() {
final Instant now = Instant.EPOCH.plus(Duration.ofDays(1)).plus(Duration.ofSeconds(100));
final Instant tomorrow = now.plus(Duration.ofDays(1)).truncatedTo(ChronoUnit.DAYS);
final Clock clock = TestClock.pinned(now);
assertThatNoException().isThrownBy(() ->
RedemptionRange.inclusive(clock, tomorrow, tomorrow.plus(RedemptionRange.MAX_REDEMPTION_DURATION)));
}
@Test
void inclusiveRange() {
final Instant now = Instant.EPOCH.plus(Duration.ofDays(1)).plus(Duration.ofSeconds(100));
final Instant today = now.truncatedTo(ChronoUnit.DAYS);
final Clock clock = TestClock.pinned(now);
for (int numDays = 0; numDays < 7; numDays++) {
final RedemptionRange range = RedemptionRange.inclusive(clock, today, today.plus(Duration.ofDays(numDays)));
final List<Instant> instants = StreamSupport.stream(range.spliterator(), false).toList();
assertThat(instants.size()).isEqualTo(numDays + 1);
}
}
}

View File

@@ -29,15 +29,12 @@ import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.ThrowableAssert;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.ArgumentCaptor;
@@ -55,6 +52,7 @@ import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialRequestContext;
import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialResponse;
import org.signal.libsignal.zkgroup.receipts.ReceiptSerial;
import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations;
import org.whispersystems.textsecuregcm.auth.RedemptionRange;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
@@ -71,18 +69,20 @@ import org.whispersystems.textsecuregcm.util.TestRandomUtil;
public class BackupAuthManagerTest {
private static final Instant NOW = Instant.now();
private final UUID aci = UUID.randomUUID();
private final byte[] messagesBackupKey = TestRandomUtil.nextBytes(32);
private final byte[] mediaBackupKey = TestRandomUtil.nextBytes(32);
private final ServerSecretParams receiptParams = ServerSecretParams.generate();
private final TestClock clock = TestClock.now();
private final TestClock clock = TestClock.pinned(NOW);
private final BackupAuthTestUtil backupAuthTestUtil = new BackupAuthTestUtil(clock);
private final AccountsManager accountsManager = mock(AccountsManager.class);
private final RedeemedReceiptsManager redeemedReceiptsManager = mock(RedeemedReceiptsManager.class);
@BeforeEach
void setUp() {
clock.unpin();
clock.pin(NOW);
reset(accountsManager);
reset(redeemedReceiptsManager);
}
@@ -181,13 +181,13 @@ public class BackupAuthManagerTest {
final BackupAuthCredentialRequestContext requestContext =
BackupAuthCredentialRequestContext.create(backupKey, aci);
final Instant start = clock.instant().truncatedTo(ChronoUnit.DAYS);
final List<BackupAuthManager.Credential> creds = authManager.getBackupAuthCredentials(account,
credentialType, start, start.plus(Duration.ofDays(1))).join();
final RedemptionRange range = range(Duration.ofDays(1));
final List<BackupAuthManager.Credential> creds =
authManager.getBackupAuthCredentials(account, credentialType, range(Duration.ofDays(1))).join();
assertThat(creds).hasSize(2);
assertThat(requestContext
.receiveResponse(creds.getFirst().credential(), start, backupAuthTestUtil.params.getPublicParams())
.receiveResponse(creds.getFirst().credential(), range.iterator().next(), backupAuthTestUtil.params.getPublicParams())
.getBackupLevel())
.isEqualTo(BackupLevel.PAID);
}
@@ -204,10 +204,7 @@ public class BackupAuthManagerTest {
.mediaCredential(backupAuthTestUtil.getRequest(mediaBackupKey, aci))
.build();
assertThat(authManager.getBackupAuthCredentials(account,
credentialType,
clock.instant().truncatedTo(ChronoUnit.DAYS),
clock.instant().plus(Duration.ofDays(1)).truncatedTo(ChronoUnit.DAYS)).join())
assertThat(authManager.getBackupAuthCredentials(account, credentialType, range(Duration.ofDays(1))).join())
.hasSize(2);
}
@@ -219,10 +216,8 @@ public class BackupAuthManagerTest {
final Account account = new MockAccountBuilder().build();
assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(() -> authManager.getBackupAuthCredentials(account,
credentialType,
clock.instant().truncatedTo(ChronoUnit.DAYS),
clock.instant().plus(Duration.ofDays(1)).truncatedTo(ChronoUnit.DAYS)).join())
.isThrownBy(() ->
authManager.getBackupAuthCredentials(account, credentialType, range(Duration.ofDays(1))).join())
.extracting(ex -> ex.getStatus().getCode())
.isEqualTo(Status.Code.NOT_FOUND);
}
@@ -246,12 +241,11 @@ public class BackupAuthManagerTest {
.messagesCredential(backupAuthTestUtil.getRequest(messagesBackupKey, aci))
.build();
final Instant start = clock.instant().truncatedTo(ChronoUnit.DAYS);
final List<BackupAuthManager.Credential> creds = authManager.getBackupAuthCredentials(account,
credentialType, start, start.plus(Duration.ofDays(7))).join();
credentialType, range(Duration.ofDays(7))).join();
assertThat(creds).hasSize(8);
Instant redemptionTime = start;
Instant redemptionTime = clock.instant().truncatedTo(ChronoUnit.DAYS);
for (BackupAuthManager.Credential cred : creds) {
assertThat(requestContext
.receiveResponse(cred.credential(), redemptionTime, backupAuthTestUtil.params.getPublicParams())
@@ -262,51 +256,10 @@ public class BackupAuthManagerTest {
}
}
static Stream<Arguments> invalidCredentialTimeWindows() {
final Duration max = Duration.ofDays(7);
final Instant day0 = Instant.EPOCH;
final Instant day1 = Instant.EPOCH.plus(Duration.ofDays(1));
return Stream.of(
// non-truncated start
Arguments.of(Instant.ofEpochSecond(100), day0.plus(max), Instant.ofEpochSecond(100)),
// non-truncated end
Arguments.of(day0, Instant.ofEpochSecond(1).plus(max), Instant.ofEpochSecond(100)),
// start to old
Arguments.of(day0, day0.plus(max), day1),
// end to new
Arguments.of(day1, day1.plus(max), day0),
// end before start
Arguments.of(day1, day0, day1),
// window too big
Arguments.of(day0, day0.plus(max).plus(Duration.ofDays(1)), Instant.ofEpochSecond(100))
);
}
@ParameterizedTest
@MethodSource
void invalidCredentialTimeWindows(final Instant requestRedemptionStart, final Instant requestRedemptionEnd,
final Instant now) {
final BackupAuthManager authManager = create();
final Account account = new MockAccountBuilder()
.messagesCredential(backupAuthTestUtil.getRequest(messagesBackupKey, aci))
.mediaCredential(backupAuthTestUtil.getRequest(mediaBackupKey, aci))
.build();
clock.pin(now);
assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(
() -> authManager.getBackupAuthCredentials(account, BackupCredentialType.MESSAGES, requestRedemptionStart, requestRedemptionEnd).join())
.extracting(ex -> ex.getStatus().getCode())
.isEqualTo(Status.Code.INVALID_ARGUMENT);
}
@Test
void expiringBackupPayment() throws VerificationFailedException {
clock.pin(Instant.ofEpochSecond(1));
final Instant day0 = Instant.EPOCH;
final Instant day4 = Instant.EPOCH.plus(Duration.ofDays(4));
final Instant dayMax = day0.plus(BackupAuthManager.MAX_REDEMPTION_DURATION);
final BackupAuthManager authManager = create();
@@ -316,8 +269,11 @@ public class BackupAuthManagerTest {
.backupVoucher(new Account.BackupVoucher(201, day4))
.build();
final List<BackupAuthManager.Credential> creds = authManager.getBackupAuthCredentials(account, BackupCredentialType.MESSAGES, day0, dayMax).join();
Instant redemptionTime = day0;
final List<BackupAuthManager.Credential> creds = authManager.getBackupAuthCredentials(
account,
BackupCredentialType.MESSAGES,
range(RedemptionRange.MAX_REDEMPTION_DURATION)).join();
Instant redemptionTime = Instant.EPOCH;
final BackupAuthCredentialRequestContext requestContext = BackupAuthCredentialRequestContext.create(
messagesBackupKey, aci);
for (int i = 0; i < creds.size(); i++) {
@@ -355,7 +311,7 @@ public class BackupAuthManagerTest {
when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(updated));
clock.pin(day2.plus(Duration.ofSeconds(1)));
assertThat(authManager.getBackupAuthCredentials(account, BackupCredentialType.MESSAGES, day2, day2.plus(Duration.ofDays(7))).join())
assertThat(authManager.getBackupAuthCredentials(account, BackupCredentialType.MESSAGES, range(Duration.ofDays(7))).join())
.hasSize(8);
@SuppressWarnings("unchecked") final ArgumentCaptor<Consumer<Account>> accountUpdater = ArgumentCaptor.forClass(
@@ -620,7 +576,7 @@ public class BackupAuthManagerTest {
private class MockAccountBuilder {
private Account account = mock(Account.class);
private final Account account = mock(Account.class);
MockAccountBuilder() {
when(account.getUuid()).thenReturn(aci);
@@ -681,4 +637,9 @@ public class BackupAuthManagerTest {
.thenReturn(rateLimitPaidMediaBackupId ? denyLimiter : allowLimiter);
return limiters;
}
private RedemptionRange range(Duration length) {
final Instant start = clock.instant().truncatedTo(ChronoUnit.DAYS);
return RedemptionRange.inclusive(clock, start, start.plus(length));
}
}

View File

@@ -21,6 +21,7 @@ import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequest;
import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequestContext;
import org.signal.libsignal.zkgroup.backups.BackupCredentialType;
import org.signal.libsignal.zkgroup.backups.BackupLevel;
import org.whispersystems.textsecuregcm.auth.RedemptionRange;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.storage.Account;
@@ -73,6 +74,8 @@ public class BackupAuthTestUtil {
case FREE -> null;
case PAID -> new Account.BackupVoucher(201L, redemptionEnd.plus(1, ChronoUnit.SECONDS));
});
return issuer.getBackupAuthCredentials(account, credentialType, redemptionStart, redemptionEnd).join();
final RedemptionRange redemptionRange;
redemptionRange = RedemptionRange.inclusive(clock, redemptionStart, redemptionEnd);
return issuer.getBackupAuthCredentials(account, credentialType, redemptionRange).join();
}
}

View File

@@ -64,6 +64,7 @@ import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations;
import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.RedemptionRange;
import org.whispersystems.textsecuregcm.backup.BackupAuthManager;
import org.whispersystems.textsecuregcm.backup.BackupAuthTestUtil;
import org.whispersystems.textsecuregcm.backup.BackupManager;
@@ -322,14 +323,15 @@ public class ArchiveControllerTest {
public void getCredentials() {
final Instant start = Instant.now().truncatedTo(ChronoUnit.DAYS);
final Instant end = start.plus(Duration.ofDays(1));
final RedemptionRange expectedRange = RedemptionRange.inclusive(Clock.systemUTC(), start, end);
final Map<BackupCredentialType, List<BackupAuthManager.Credential>> expectedCredentialsByType =
EnumMapUtil.toEnumMap(BackupCredentialType.class, credentialType -> backupAuthTestUtil.getCredentials(
BackupLevel.PAID, backupAuthTestUtil.getRequest(messagesBackupKey, aci), credentialType, start, end));
expectedCredentialsByType.forEach((credentialType, expectedCredentials) ->
when(backupAuthManager.getBackupAuthCredentials(any(), eq(credentialType), eq(start), eq(end))).thenReturn(
CompletableFuture.completedFuture(expectedCredentials)));
when(backupAuthManager.getBackupAuthCredentials(any(), eq(credentialType), eq(expectedRange)))
.thenReturn(CompletableFuture.completedFuture(expectedCredentials)));
final ArchiveController.BackupAuthCredentialsResponse credentialResponse = resources.getJerseyTest()
.target("v1/archives/auth")

View File

@@ -28,9 +28,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.Mock;
import org.signal.chat.backup.BackupsGrpc;
import org.signal.chat.backup.GetBackupAuthCredentialsRequest;
@@ -51,6 +49,7 @@ import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialRequestContext;
import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialResponse;
import org.signal.libsignal.zkgroup.receipts.ReceiptSerial;
import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations;
import org.whispersystems.textsecuregcm.auth.RedemptionRange;
import org.whispersystems.textsecuregcm.backup.BackupAuthManager;
import org.whispersystems.textsecuregcm.backup.BackupAuthTestUtil;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
@@ -193,6 +192,7 @@ class BackupsGrpcServiceTest extends SimpleBaseGrpcTest<BackupsGrpcService, Back
void getCredentials() {
final Instant start = Instant.now().truncatedTo(ChronoUnit.DAYS);
final Instant end = start.plus(Duration.ofDays(1));
final RedemptionRange expectedRange = RedemptionRange.inclusive(Clock.systemUTC(), start, end);
final Map<BackupCredentialType, List<BackupAuthManager.Credential>> expectedCredentialsByType =
EnumMapUtil.toEnumMap(BackupCredentialType.class, credentialType -> backupAuthTestUtil.getCredentials(
@@ -200,7 +200,7 @@ class BackupsGrpcServiceTest extends SimpleBaseGrpcTest<BackupsGrpcService, Back
start, end));
expectedCredentialsByType.forEach((credentialType, expectedCredentials) ->
when(backupAuthManager.getBackupAuthCredentials(any(), eq(credentialType), eq(start), eq(end)))
when(backupAuthManager.getBackupAuthCredentials(any(), eq(credentialType), eq(expectedRange)))
.thenReturn(CompletableFuture.completedFuture(expectedCredentials)));
final GetBackupAuthCredentialsResponse credentialResponse = authenticatedServiceStub().getBackupAuthCredentials(