Convert VerificationSessions to sync DynamoDB client

This commit is contained in:
Chris Eager
2026-05-19 17:30:53 -05:00
committed by Jon Chambers
parent defbc1c853
commit 66b0ed16d1
10 changed files with 127 additions and 217 deletions
@@ -47,7 +47,7 @@ public class IntegrationTools {
config.dynamoDbTables().registrationRecovery(), Duration.ofDays(1), dynamoDbAsyncClient, Clock.systemUTC());
final VerificationSessions verificationSessions = new VerificationSessions(
dynamoDbAsyncClient, config.dynamoDbTables().verificationSessions(), Clock.systemUTC());
dynamoDbClient, config.dynamoDbTables().verificationSessions(), Clock.systemUTC());
return new IntegrationTools(
new RegistrationRecoveryPasswordsManager(registrationRecoveryPasswords),
@@ -75,9 +75,8 @@ public class IntegrationTools {
.thenRun(Util.NOOP);
}
public CompletableFuture<Optional<String>> peekVerificationSessionPushChallenge(final String sessionId) {
return verificationSessionManager.findForId(sessionId)
.thenApply(maybeSession -> maybeSession.map(VerificationSession::pushChallenge));
public Optional<String> peekVerificationSessionPushChallenge(final String sessionId) {
return verificationSessionManager.findForId(sessionId).map(VerificationSession::pushChallenge);
}
public void clearChangeNumberWaitingPeriod(TestUser user) {
@@ -117,7 +117,7 @@ public final class Operations {
}
public static String peekVerificationSessionPushChallenge(final String sessionId) {
return INTEGRATION_TOOLS.peekVerificationSessionPushChallenge(sessionId).join()
return INTEGRATION_TOOLS.peekVerificationSessionPushChallenge(sessionId)
.orElseThrow(() -> new RuntimeException("push challenge not found for the verification session"));
}
@@ -519,7 +519,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
dynamoDbAsyncClient,
clock);
final VerificationSessions verificationSessions = new VerificationSessions(dynamoDbAsyncClient,
final VerificationSessions verificationSessions = new VerificationSessions(dynamoDbClient,
config.getDynamoDbTables().getVerificationSessions().getTableName(), clock);
final ClientResources sharedClientResources = ClientResources.builder()
@@ -61,7 +61,6 @@ import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletionException;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.Strings;
import org.apache.http.HttpStatus;
import org.slf4j.Logger;
@@ -116,7 +115,6 @@ public class VerificationController {
private static final Logger logger = LoggerFactory.getLogger(VerificationController.class);
private static final Duration REGISTRATION_RPC_TIMEOUT = Duration.ofSeconds(15);
private static final Duration DYNAMODB_TIMEOUT = Duration.ofSeconds(5);
private static final SecureRandom RANDOM = new SecureRandom();
@@ -264,7 +262,7 @@ public class VerificationController {
// if a push challenge sent in `handlePushToken` doesn't arrive in time
verificationSession.requestedInformation().add(VerificationSession.Information.CAPTCHA);
storeVerificationSession(verificationSession);
verificationSessionManager.insert(verificationSession);
return buildResponse(registrationServiceSession, verificationSession);
}
@@ -337,24 +335,12 @@ public class VerificationController {
} finally {
// Each of the handle* methods may update requestedInformation, submittedInformation, and allowedToRequestCode,
// and we want to be sure to store a changes, even if a later method throws
updateStoredVerificationSession(verificationSession);
verificationSessionManager.update(verificationSession);
}
return buildResponse(registrationServiceSession, verificationSession);
}
private void storeVerificationSession(final VerificationSession verificationSession) {
verificationSessionManager.insert(verificationSession)
.orTimeout(DYNAMODB_TIMEOUT.toSeconds(), TimeUnit.SECONDS)
.join();
}
private void updateStoredVerificationSession(final VerificationSession verificationSession) {
verificationSessionManager.update(verificationSession)
.orTimeout(DYNAMODB_TIMEOUT.toSeconds(), TimeUnit.SECONDS)
.join();
}
/**
* If {@code pushTokenAndType} values are not {@code null}, sends a push challenge. If there is no existing push
* challenge in the session, one will be created, set on the returned session record, and
@@ -907,8 +893,7 @@ public class VerificationController {
private VerificationSession retrieveVerificationSession(final RegistrationServiceSession registrationServiceSession) {
return verificationSessionManager.findForId(registrationServiceSession.encodedSessionId())
.orTimeout(5, TimeUnit.SECONDS)
.join().orElseThrow(NotFoundException::new);
.orElseThrow(NotFoundException::new);
}
/**
@@ -14,16 +14,16 @@ import java.time.Clock;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
public abstract class SerializedExpireableJsonDynamoStore<T> {
@@ -34,7 +34,7 @@ public abstract class SerializedExpireableJsonDynamoStore<T> {
long getExpirationEpochSeconds();
}
private final DynamoDbAsyncClient dynamoDbClient;
private final DynamoDbClient dynamoDbClient;
private final String tableName;
private final Clock clock;
private final Class<T> deserializationTargetClass;
@@ -47,7 +47,7 @@ public abstract class SerializedExpireableJsonDynamoStore<T> {
private final Logger log = LoggerFactory.getLogger(getClass());
public SerializedExpireableJsonDynamoStore(final DynamoDbAsyncClient dynamoDbClient, final String tableName,
public SerializedExpireableJsonDynamoStore(final DynamoDbClient dynamoDbClient, final String tableName,
final Clock clock) {
this.dynamoDbClient = dynamoDbClient;
this.tableName = tableName;
@@ -67,18 +67,16 @@ public abstract class SerializedExpireableJsonDynamoStore<T> {
}
}
public CompletableFuture<Void> insert(final String key, final T v) {
return put(key, v, builder -> builder.expressionAttributeNames(Map.of(
"#key", KEY_KEY
)).conditionExpression("attribute_not_exists(#key)"));
public void insert(final String key, final T v) {
put(key, v, builder -> builder.expressionAttributeNames(Map.of("#key", KEY_KEY))
.conditionExpression("attribute_not_exists(#key)"));
}
public CompletableFuture<Void> update(final String key, final T v) {
return put(key, v, ignored -> {
});
public void update(final String key, final T v) {
put(key, v, _ -> {});
}
private CompletableFuture<Void> put(final String key, final T v,
private void put(final String key, final T v,
final Consumer<PutItemRequest.Builder> putRequestCustomizer) {
try {
final Map<String, AttributeValue> attributeValueMap = new HashMap<>(Map.of(
@@ -93,9 +91,7 @@ public abstract class SerializedExpireableJsonDynamoStore<T> {
.item(attributeValueMap);
putRequestCustomizer.accept(builder);
return dynamoDbClient.putItem(builder.build())
.thenRun(() -> {
});
dynamoDbClient.putItem(builder.build());
} catch (final JsonProcessingException e) {
// This should never happen when writing directly to a string except in cases of serious misconfiguration, which
// would be caught by tests.
@@ -107,24 +103,23 @@ public abstract class SerializedExpireableJsonDynamoStore<T> {
return v.getExpirationEpochSeconds();
}
public CompletableFuture<Optional<T>> findForKey(final String key) {
return dynamoDbClient.getItem(GetItemRequest.builder()
public Optional<T> findForKey(final String key) {
final GetItemResponse response = dynamoDbClient.getItem(GetItemRequest.builder()
.tableName(tableName)
.consistentRead(true)
.key(Map.of(KEY_KEY, AttributeValues.fromString(key)))
.build())
.thenApply(response -> {
try {
return response.hasItem()
? filterMaybeExpiredValue(
SystemMapper.jsonMapper()
.readValue(response.item().get(ATTR_SERIALIZED_VALUE).s(), deserializationTargetClass))
: Optional.empty();
} catch (final JsonProcessingException e) {
log.error("Failed to parse stored value", e);
return Optional.empty();
}
});
.build());
try {
return response.hasItem()
? filterMaybeExpiredValue(
SystemMapper.jsonMapper()
.readValue(response.item().get(ATTR_SERIALIZED_VALUE).s(), deserializationTargetClass))
: Optional.empty();
} catch (final JsonProcessingException e) {
log.error("Failed to parse stored value", e);
return Optional.empty();
}
}
private Optional<T> filterMaybeExpiredValue(T v) {
@@ -139,13 +134,11 @@ public abstract class SerializedExpireableJsonDynamoStore<T> {
return Optional.of(v);
}
public CompletableFuture<Void> remove(final String key) {
return dynamoDbClient.deleteItem(DeleteItemRequest.builder()
public void remove(final String key) {
dynamoDbClient.deleteItem(DeleteItemRequest.builder()
.tableName(tableName)
.key(Map.of(KEY_KEY, AttributeValues.fromString(key)))
.build())
.thenRun(() -> {
});
.build());
}
}
@@ -6,7 +6,6 @@
package org.whispersystems.textsecuregcm.storage;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import org.whispersystems.textsecuregcm.registration.VerificationSession;
public class VerificationSessionManager {
@@ -17,15 +16,15 @@ public class VerificationSessionManager {
this.verificationSessions = verificationSessions;
}
public CompletableFuture<Void> insert(final VerificationSession verificationSession) {
return verificationSessions.insert(verificationSession.sessionId(), verificationSession);
public void insert(final VerificationSession verificationSession) {
verificationSessions.insert(verificationSession.sessionId(), verificationSession);
}
public CompletableFuture<Void> update(final VerificationSession verificationSession) {
return verificationSessions.update(verificationSession.sessionId(), verificationSession);
public void update(final VerificationSession verificationSession) {
verificationSessions.update(verificationSession.sessionId(), verificationSession);
}
public CompletableFuture<Optional<VerificationSession>> findForId(final String encodedSessionId) {
public Optional<VerificationSession> findForId(final String encodedSessionId) {
return verificationSessions.findForKey(encodedSessionId);
}
@@ -7,11 +7,11 @@ package org.whispersystems.textsecuregcm.storage;
import java.time.Clock;
import org.whispersystems.textsecuregcm.registration.VerificationSession;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
public class VerificationSessions extends SerializedExpireableJsonDynamoStore<VerificationSession> {
public VerificationSessions(final DynamoDbAsyncClient dynamoDbClient, final String tableName, final Clock clock) {
public VerificationSessions(final DynamoDbClient dynamoDbClient, final String tableName, final Clock clock) {
super(dynamoDbClient, tableName, clock);
}
}
@@ -240,8 +240,6 @@ class VerificationControllerTest {
CompletableFuture.completedFuture(
new RegistrationServiceSession(SESSION_ID, requestedNumber, false, null, null, null,
SESSION_EXPIRATION_SECONDS)));
when(verificationSessionManager.insert(any()))
.thenReturn(CompletableFuture.completedFuture(null));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session")
@@ -281,8 +279,6 @@ class VerificationControllerTest {
CompletableFuture.completedFuture(
new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null,
SESSION_EXPIRATION_SECONDS)));
when(verificationSessionManager.insert(any()))
.thenReturn(CompletableFuture.completedFuture(null));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session")
@@ -302,8 +298,6 @@ class VerificationControllerTest {
CompletableFuture.completedFuture(
new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null,
SESSION_EXPIRATION_SECONDS)));
when(verificationSessionManager.insert(any()))
.thenReturn(CompletableFuture.completedFuture(null));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session")
@@ -337,8 +331,6 @@ class VerificationControllerTest {
new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null,
SESSION_EXPIRATION_SECONDS)));
when(verificationSessionManager.insert(any()))
.thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.getByE164(NUMBER))
.thenReturn(isReregistration ? Optional.of(mock(Account.class)) : Optional.empty());
@@ -400,13 +392,9 @@ class VerificationControllerTest {
Optional.of(
registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(
.thenReturn(Optional.of(
new VerificationSession(encodedSessionId, null, null, List.of(VerificationSession.Information.CAPTCHA), Collections.emptyList(),
null, null, false, clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
when(verificationSessionManager.update(any()))
.thenReturn(CompletableFuture.completedFuture(null));
null, null, false, clock.millis(), clock.millis(), registrationServiceSession.expiration())));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session/" + encodedSessionId)
@@ -445,12 +433,8 @@ class VerificationControllerTest {
Optional.of(
registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, false,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
when(verificationSessionManager.update(any()))
.thenReturn(CompletableFuture.completedFuture(null));
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, false,
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
doThrow(RateLimitExceededException.class)
.when(captchaLimiter).validate(anyString());
@@ -481,12 +465,8 @@ class VerificationControllerTest {
Optional.of(
registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, false,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
when(verificationSessionManager.update(any()))
.thenReturn(CompletableFuture.completedFuture(null));
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, false,
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
doThrow(RateLimitExceededException.class)
.when(pushChallengeLimiter).validate(anyString());
@@ -517,12 +497,9 @@ class VerificationControllerTest {
Optional.of(
registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, "challenge", null, List.of(VerificationSession.Information.PUSH_CHALLENGE),
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, "challenge", null, List.of(VerificationSession.Information.PUSH_CHALLENGE),
Collections.emptyList(), null, null, false, clock.millis(), clock.millis(),
registrationServiceSession.expiration()))));
when(verificationSessionManager.update(any()))
.thenReturn(CompletableFuture.completedFuture(null));
registrationServiceSession.expiration())));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session/" + encodedSessionId)
@@ -551,17 +528,13 @@ class VerificationControllerTest {
Optional.of(
registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, List.of(VerificationSession.Information.CAPTCHA),
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, List.of(VerificationSession.Information.CAPTCHA),
Collections.emptyList(), null, null, false, clock.millis(), clock.millis(),
registrationServiceSession.expiration()))));
registrationServiceSession.expiration())));
when(registrationCaptchaManager.assessCaptcha(any(), any(), any(), any()))
.thenReturn(Optional.of(AssessmentResult.invalid()));
when(verificationSessionManager.update(any()))
.thenReturn(CompletableFuture.completedFuture(null));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session/" + encodedSessionId)
.request()
@@ -598,17 +571,14 @@ class VerificationControllerTest {
Optional.of(
registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId,
.thenReturn(Optional.of(new VerificationSession(encodedSessionId,
"challenge",
null,
List.of(VerificationSession.Information.CAPTCHA),
List.of(VerificationSession.Information.PUSH_CHALLENGE),
null, null, false,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
when(verificationSessionManager.update(any()))
.thenReturn(CompletableFuture.completedFuture(null));
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session/" + encodedSessionId)
.request()
@@ -645,12 +615,9 @@ class VerificationControllerTest {
.thenReturn(CompletableFuture.completedFuture(
Optional.of(registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, "challenge", null, List.of(), List.of(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, "challenge", null, List.of(), List.of(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
when(verificationSessionManager.update(any()))
.thenReturn(CompletableFuture.completedFuture(null));
when(registrationRecoveryPasswordsManager.remove(PNI))
.thenReturn(CompletableFuture.completedFuture(null));
@@ -683,15 +650,12 @@ class VerificationControllerTest {
Optional.of(
registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId,
.thenReturn(Optional.of(new VerificationSession(encodedSessionId,
"challenge",
null,
List.of(VerificationSession.Information.PUSH_CHALLENGE, VerificationSession.Information.CAPTCHA),
Collections.emptyList(), null, null, false, clock.millis(), clock.millis(),
registrationServiceSession.expiration()))));
when(verificationSessionManager.update(any()))
.thenReturn(CompletableFuture.completedFuture(null));
registrationServiceSession.expiration())));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session/" + encodedSessionId)
@@ -729,16 +693,13 @@ class VerificationControllerTest {
Optional.of(
registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, List.of(VerificationSession.Information.CAPTCHA),
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, List.of(VerificationSession.Information.CAPTCHA),
Collections.emptyList(), null, null, false, clock.millis(), clock.millis(),
registrationServiceSession.expiration()))));
registrationServiceSession.expiration())));
when(registrationCaptchaManager.assessCaptcha(any(), any(), any(), any()))
.thenReturn(Optional.of(AssessmentResult.alwaysValid()));
when(verificationSessionManager.update(any()))
.thenReturn(CompletableFuture.completedFuture(null));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session/" + encodedSessionId)
@@ -776,19 +737,16 @@ class VerificationControllerTest {
Optional.of(
registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId,
.thenReturn(Optional.of(new VerificationSession(encodedSessionId,
"challenge",
null,
List.of(VerificationSession.Information.CAPTCHA, VerificationSession.Information.CAPTCHA),
Collections.emptyList(), null, null, false, clock.millis(), clock.millis(),
registrationServiceSession.expiration()))));
registrationServiceSession.expiration())));
when(registrationCaptchaManager.assessCaptcha(any(), any(), any(), any()))
.thenReturn(Optional.of(AssessmentResult.alwaysValid()));
when(verificationSessionManager.update(any()))
.thenReturn(CompletableFuture.completedFuture(null));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session/" + encodedSessionId)
@@ -825,19 +783,16 @@ class VerificationControllerTest {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId,
.thenReturn(Optional.of(new VerificationSession(encodedSessionId,
null,
null,
List.of(VerificationSession.Information.CAPTCHA),
Collections.emptyList(), null, null, false, clock.millis(), clock.millis(),
registrationServiceSession.expiration()))));
registrationServiceSession.expiration())));
when(registrationCaptchaManager.assessCaptcha(any(), any(), any(), any()))
.thenThrow(new IOException("expected service error"));
when(verificationSessionManager.update(any()))
.thenReturn(CompletableFuture.completedFuture(null));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session/" + encodedSessionId)
@@ -892,7 +847,7 @@ class VerificationControllerTest {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(verificationSessionManager.findForId(encodeSessionId(SESSION_ID)))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
.thenReturn(Optional.empty());
Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session/" + encodeSessionId(SESSION_ID))
@@ -940,7 +895,7 @@ class VerificationControllerTest {
new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null,
SESSION_EXPIRATION_SECONDS))));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(mock(VerificationSession.class))));
.thenReturn(Optional.of(mock(VerificationSession.class)));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session/" + encodedSessionId)
@@ -962,7 +917,7 @@ class VerificationControllerTest {
Optional.of(
registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(mock(VerificationSession.class))));
.thenReturn(Optional.of(mock(VerificationSession.class)));
when(registrationRecoveryPasswordsManager.remove(PNI))
.thenReturn(CompletableFuture.completedFuture(null));
@@ -986,9 +941,8 @@ class VerificationControllerTest {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
when(registrationServiceClient.sendVerificationCode(any(), any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(registrationServiceSession));
when(registrationRecoveryPasswordsManager.remove(PNI)).thenReturn(CompletableFuture.completedFuture(null));
@@ -1018,9 +972,9 @@ class VerificationControllerTest {
Optional.of(
registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(new VerificationSession(encodedSessionId, null, null,
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null,
List.of(VerificationSession.Information.CAPTCHA), Collections.emptyList(), null, null, false, clock.millis(), clock.millis(),
registrationServiceSession.expiration()))));
registrationServiceSession.expiration())));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session/" + encodedSessionId + "/code")
@@ -1049,9 +1003,8 @@ class VerificationControllerTest {
Optional.of(
registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, false,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, false,
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session/" + encodedSessionId + "/code")
@@ -1077,9 +1030,8 @@ class VerificationControllerTest {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
when(registrationServiceClient.sendVerificationCode(any(), any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.failedFuture(
new CompletionException(new VerificationSessionRateLimitExceededException(registrationServiceSession,
@@ -1109,9 +1061,8 @@ class VerificationControllerTest {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
when(registrationServiceClient.sendVerificationCode(any(), any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.failedFuture(
new CompletionException(new TransportNotAllowedException(registrationServiceSession))));
@@ -1142,9 +1093,8 @@ class VerificationControllerTest {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
when(registrationServiceClient.sendVerificationCode(any(), any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(registrationServiceSession));
@@ -1193,9 +1143,8 @@ class VerificationControllerTest {
.thenReturn(CompletableFuture.completedFuture(
Optional.of(registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
when(registrationServiceClient.sendVerificationCode(any(), any(), any(), any(), any(), any()))
.thenReturn(
@@ -1237,9 +1186,8 @@ class VerificationControllerTest {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
when(registrationServiceClient.sendVerificationCode(any(), any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.failedFuture(new CompletionException(
@@ -1271,9 +1219,8 @@ class VerificationControllerTest {
.thenReturn(CompletableFuture.completedFuture(
Optional.of(registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
when(registrationServiceClient.checkVerificationCode(any(), any(), any()))
.thenReturn(CompletableFuture.failedFuture(new CompletionException(new RuntimeException())));
@@ -1298,9 +1245,8 @@ class VerificationControllerTest {
.thenReturn(CompletableFuture.completedFuture(
Optional.of(registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
when(registrationRecoveryPasswordsManager.remove(PNI))
.thenReturn(CompletableFuture.completedFuture(null));
@@ -1335,9 +1281,8 @@ class VerificationControllerTest {
.thenReturn(CompletableFuture.completedFuture(
Optional.of(registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
// There is no explicit indication in the exception that no code has been sent, but we treat all RegistrationServiceExceptions
// in which the response has a session object as conflicted state
@@ -1372,9 +1317,8 @@ class VerificationControllerTest {
.thenReturn(CompletableFuture.completedFuture(
Optional.of(registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
when(registrationServiceClient.checkVerificationCode(any(), any(), any()))
.thenReturn(CompletableFuture.failedFuture(new CompletionException(new RegistrationServiceException(null))));
@@ -1398,9 +1342,8 @@ class VerificationControllerTest {
.thenReturn(CompletableFuture.completedFuture(
Optional.of(registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
when(registrationServiceClient.checkVerificationCode(any(), any(), any()))
.thenReturn(CompletableFuture.failedFuture(
new CompletionException(new VerificationSessionRateLimitExceededException(registrationServiceSession,
@@ -1430,9 +1373,8 @@ class VerificationControllerTest {
.thenReturn(CompletableFuture.completedFuture(
Optional.of(registrationServiceSession)));
when(verificationSessionManager.findForId(any()))
.thenReturn(CompletableFuture.completedFuture(
Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration()))));
.thenReturn(Optional.of(new VerificationSession(encodedSessionId, null, null, Collections.emptyList(), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), registrationServiceSession.expiration())));
when(registrationRecoveryPasswordsManager.remove(any()))
.thenReturn(CompletableFuture.completedFuture(true));
@@ -16,12 +16,11 @@ import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType;
@@ -54,7 +53,7 @@ class SerializedExpireableJsonDynamoStoreTest {
private SerializedExpireableJsonDynamoStore<T> store;
abstract SerializedExpireableJsonDynamoStore<T> getStore(final DynamoDbAsyncClient dynamoDbClient,
abstract SerializedExpireableJsonDynamoStore<T> getStore(final DynamoDbClient dynamoDbClient,
final String tableName);
abstract T testValue(final String v);
@@ -63,29 +62,29 @@ class SerializedExpireableJsonDynamoStoreTest {
@BeforeEach
void setUp() {
store = getStore(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), TABLE_NAME);
store = getStore(DYNAMO_DB_EXTENSION.getDynamoDbClient(), TABLE_NAME);
}
@Test
void testStoreAndFind() throws Exception {
assertEquals(Optional.empty(), store.findForKey(KEY).get(1, TimeUnit.SECONDS));
void testStoreAndFind() {
assertEquals(Optional.empty(), store.findForKey(KEY));
final T original = testValue("1234");
final T second = testValue("5678");
store.insert(KEY, original).get(1, TimeUnit.SECONDS);
store.insert(KEY, original);
{
final Optional<T> maybeValue = store.findForKey(KEY).get(1, TimeUnit.SECONDS);
final Optional<T> maybeValue = store.findForKey(KEY);
assertTrue(maybeValue.isPresent());
assertEquals(original, maybeValue.get());
}
assertThrows(Exception.class, () -> store.insert(KEY, second).get(1, TimeUnit.SECONDS));
assertDoesNotThrow(() -> store.update(KEY, second).get(1, TimeUnit.SECONDS));
assertThrows(Exception.class, () -> store.insert(KEY, second));
assertDoesNotThrow(() -> store.update(KEY, second));
{
final Optional<T> maybeValue = store.findForKey(KEY).get(1, TimeUnit.SECONDS);
final Optional<T> maybeValue = store.findForKey(KEY);
assertTrue(maybeValue.isPresent());
assertEquals(second, maybeValue.get());
@@ -93,20 +92,20 @@ class SerializedExpireableJsonDynamoStoreTest {
}
@Test
void testRemove() throws Exception {
assertEquals(Optional.empty(), store.findForKey(KEY).get(1, TimeUnit.SECONDS));
void testRemove() {
assertEquals(Optional.empty(), store.findForKey(KEY));
store.insert(KEY, testValue("1234")).get(1, TimeUnit.SECONDS);
assertTrue(store.findForKey(KEY).get(1, TimeUnit.SECONDS).isPresent());
store.insert(KEY, testValue("1234"));
assertTrue(store.findForKey(KEY).isPresent());
store.remove(KEY).get(1, TimeUnit.SECONDS);
assertFalse(store.findForKey(KEY).get(1, TimeUnit.SECONDS).isPresent());
store.remove(KEY);
assertFalse(store.findForKey(KEY).isPresent());
final T v = maybeExpiredTestValue("1234");
store.insert(KEY, v).get(1, TimeUnit.SECONDS);
store.insert(KEY, v);
assertEquals(v instanceof SerializedExpireableJsonDynamoStore.Expireable,
store.findForKey(KEY).get(1, TimeUnit.SECONDS).isEmpty());
store.findForKey(KEY).isEmpty());
}
}
@@ -126,7 +125,7 @@ class SerializedExpireableJsonDynamoStoreTest {
class ExpiresStore extends SerializedExpireableJsonDynamoStore<Expires> {
public ExpiresStore(final DynamoDbAsyncClient dynamoDbClient, final String tableName) {
public ExpiresStore(final DynamoDbClient dynamoDbClient, final String tableName) {
super(dynamoDbClient, tableName, clock);
}
}
@@ -136,7 +135,7 @@ class SerializedExpireableJsonDynamoStoreTest {
Duration.ofHours(1)).toEpochMilli();
@Override
SerializedExpireableJsonDynamoStore<Expires> getStore(final DynamoDbAsyncClient dynamoDbClient,
SerializedExpireableJsonDynamoStore<Expires> getStore(final DynamoDbClient dynamoDbClient,
final String tableName) {
return new ExpiresStore(dynamoDbClient, tableName);
}
@@ -162,13 +161,13 @@ class SerializedExpireableJsonDynamoStoreTest {
class DoesNotExpireStore extends SerializedExpireableJsonDynamoStore<DoesNotExpire> {
public DoesNotExpireStore(final DynamoDbAsyncClient dynamoDbClient, final String tableName) {
public DoesNotExpireStore(final DynamoDbClient dynamoDbClient, final String tableName) {
super(dynamoDbClient, tableName, clock);
}
}
@Override
SerializedExpireableJsonDynamoStore<DoesNotExpire> getStore(final DynamoDbAsyncClient dynamoDbClient,
SerializedExpireableJsonDynamoStore<DoesNotExpire> getStore(final DynamoDbClient dynamoDbClient,
final String tableName) {
return new DoesNotExpireStore(dynamoDbClient, tableName);
}
@@ -6,7 +6,6 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -17,14 +16,12 @@ import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletionException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.registration.VerificationSession;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.telephony.CarrierData;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
class VerificationSessionsTest {
@@ -39,7 +36,7 @@ class VerificationSessionsTest {
@BeforeEach
void setUp() {
verificationSessions = new VerificationSessions(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), Tables.VERIFICATION_SESSIONS.tableName(), clock);
DYNAMO_DB_EXTENSION.getDynamoDbClient(), Tables.VERIFICATION_SESSIONS.tableName(), clock);
}
@Test
@@ -62,30 +59,26 @@ class VerificationSessionsTest {
final String sessionId = "sessionId";
final Optional<VerificationSession> absentSession = verificationSessions.findForKey(sessionId).join();
final Optional<VerificationSession> absentSession = verificationSessions.findForKey(sessionId);
assertTrue(absentSession.isEmpty());
final VerificationSession session = new VerificationSession(sessionId, null, new CarrierData("Test", CarrierData.LineType.MOBILE, Optional.of("123"), Optional.empty(), Optional.empty(), Optional.empty()),
List.of(VerificationSession.Information.PUSH_CHALLENGE), Collections.emptyList(), null, null, true,
clock.millis(), clock.millis(), Duration.ofMinutes(1).toSeconds());
verificationSessions.insert(sessionId, session).join();
verificationSessions.insert(sessionId, session);
assertEquals(session, verificationSessions.findForKey(sessionId).join().orElseThrow());
assertEquals(session, verificationSessions.findForKey(sessionId).orElseThrow());
final CompletionException ce = assertThrows(CompletionException.class,
() -> verificationSessions.insert(sessionId, session).join());
final Throwable t = ExceptionUtils.unwrap(ce);
assertInstanceOf(ConditionalCheckFailedException.class, t,
assertThrows(ConditionalCheckFailedException.class, () -> verificationSessions.insert(sessionId, session),
"inserting with the same key should fail conditional checks");
final VerificationSession updatedSession = new VerificationSession(sessionId, null, new CarrierData("Test", CarrierData.LineType.MOBILE, Optional.of("123"), Optional.empty(), Optional.empty(), Optional.empty()), Collections.emptyList(),
List.of(VerificationSession.Information.PUSH_CHALLENGE), null, null, true, clock.millis(), clock.millis(),
Duration.ofMinutes(2).toSeconds());
verificationSessions.update(sessionId, updatedSession).join();
verificationSessions.update(sessionId, updatedSession);
assertEquals(updatedSession, verificationSessions.findForKey(sessionId).join().orElseThrow());
assertEquals(updatedSession, verificationSessions.findForKey(sessionId).orElseThrow());
});
}