Explicitly create registration sessions

This commit is contained in:
Jon Chambers
2022-11-17 12:18:55 -05:00
committed by Jon Chambers
parent 9e1485de0a
commit 7018062606
4 changed files with 300 additions and 69 deletions

View File

@@ -29,6 +29,7 @@ import java.util.ArrayList;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletionException;
import javax.annotation.Nullable;
import javax.servlet.http.HttpServletRequest;
import javax.validation.Valid;
@@ -232,7 +233,7 @@ public class AccountController {
@PathParam("token") String pushToken,
@PathParam("number") String number,
@QueryParam("voip") @DefaultValue("true") boolean useVoip)
throws ImpossiblePhoneNumberException, NonNormalizedPhoneNumberException {
throws ImpossiblePhoneNumberException, NonNormalizedPhoneNumberException, RateLimitExceededException {
final PushNotification.TokenType tokenType = switch(pushType) {
case "apn" -> useVoip ? PushNotification.TokenType.APN_VOIP : PushNotification.TokenType.APN;
@@ -242,9 +243,18 @@ public class AccountController {
Util.requireNormalizedNumber(number);
String pushChallenge = generatePushChallenge();
StoredVerificationCode storedVerificationCode =
new StoredVerificationCode(null, clock.millis(), pushChallenge, null);
final Phonenumber.PhoneNumber phoneNumber;
try {
phoneNumber = PhoneNumberUtil.getInstance().parse(number, null);
} catch (final NumberParseException e) {
// This should never happen since we just verified that the number is already normalized
throw new BadRequestException("Bad phone number");
}
final String pushChallenge = generatePushChallenge();
final byte[] sessionId = createRegistrationSession(phoneNumber);
final StoredVerificationCode storedVerificationCode =
new StoredVerificationCode(null, clock.millis(), pushChallenge, sessionId);
pendingAccounts.store(number, storedVerificationCode);
pushNotificationManager.sendRegistrationChallengeNotification(pushToken, tokenType, storedVerificationCode.pushCode());
@@ -346,8 +356,16 @@ public class AccountController {
}
}).orElse(ClientType.UNKNOWN);
final byte[] sessionId = registrationServiceClient.sendRegistrationCode(phoneNumber,
messageTransport, clientType, acceptLanguage.orElse(null), REGISTRATION_RPC_TIMEOUT).join();
// During the transition to explicit session creation, some previously-stored records may not have a session ID;
// after the transition, we can assume that any existing record has an associated session ID.
final byte[] sessionId = maybeStoredVerificationCode.isPresent() && maybeStoredVerificationCode.get().sessionId() != null ?
maybeStoredVerificationCode.get().sessionId() : createRegistrationSession(phoneNumber);
registrationServiceClient.sendRegistrationCode(sessionId,
messageTransport,
clientType,
acceptLanguage.orElse(null),
REGISTRATION_RPC_TIMEOUT).join();
final StoredVerificationCode storedVerificationCode = new StoredVerificationCode(null,
clock.millis(),
@@ -940,4 +958,23 @@ public class AccountController {
return Hex.toStringCondensed(challenge);
}
private byte[] createRegistrationSession(final Phonenumber.PhoneNumber phoneNumber) throws RateLimitExceededException {
try {
return registrationServiceClient.createRegistrationSession(phoneNumber, REGISTRATION_RPC_TIMEOUT).join();
} catch (final CompletionException e) {
Throwable cause = e;
while (cause instanceof CompletionException) {
cause = cause.getCause();
}
if (cause instanceof RateLimitExceededException rateLimitExceededException) {
throw rateLimitExceededException;
}
throw e;
}
}
}

View File

@@ -16,15 +16,19 @@ import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.StringUtils;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.signal.registration.rpc.CheckVerificationCodeRequest;
import org.signal.registration.rpc.CheckVerificationCodeResponse;
import org.signal.registration.rpc.CreateRegistrationSessionRequest;
import org.signal.registration.rpc.RegistrationServiceGrpc;
import org.signal.registration.rpc.SendVerificationCodeRequest;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
public class RegistrationServiceClient implements Managed {
@@ -52,17 +56,36 @@ public class RegistrationServiceClient implements Managed {
this.callbackExecutor = callbackExecutor;
}
public CompletableFuture<byte[]> sendRegistrationCode(final Phonenumber.PhoneNumber phoneNumber,
public CompletableFuture<byte[]> createRegistrationSession(final Phonenumber.PhoneNumber phoneNumber, final Duration timeout) {
final long e164 = Long.parseLong(
PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164).substring(1));
return toCompletableFuture(stub.withDeadline(toDeadline(timeout))
.createSession(CreateRegistrationSessionRequest.newBuilder()
.setE164(e164)
.build()))
.thenApply(response -> switch (response.getResponseCase()) {
case SESSION_METADATA -> response.getSessionMetadata().getSessionId().toByteArray();
case ERROR -> {
switch (response.getError().getErrorType()) {
case ERROR_TYPE_RATE_LIMITED -> throw new CompletionException(new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds())));
default -> throw new RuntimeException("Unrecognized error type from registration service: " + response.getError().getErrorType());
}
}
case RESPONSE_NOT_SET -> throw new RuntimeException("No response from registration service");
});
}
public CompletableFuture<byte[]> sendRegistrationCode(final byte[] sessionId,
final MessageTransport messageTransport,
final ClientType clientType,
@Nullable final String acceptLanguage,
final Duration timeout) {
final long e164 = Long.parseLong(
PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164).substring(1));
final SendVerificationCodeRequest.Builder requestBuilder = SendVerificationCodeRequest.newBuilder()
.setE164(e164)
.setSessionId(ByteString.copyFrom(sessionId))
.setTransport(getRpcMessageTransport(messageTransport))
.setClientType(getRpcClientType(clientType));