Add collation key to registration service session creation rpc call

This commit is contained in:
Chris Eager
2025-01-22 12:01:12 -06:00
committed by Chris Eager
parent 5cc76f48aa
commit 47550d48e7
10 changed files with 67 additions and 17 deletions

View File

@@ -3,9 +3,11 @@ package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonTypeName;
import io.dropwizard.core.setup.Environment;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import java.io.IOException;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
import org.whispersystems.textsecuregcm.registration.IdentityTokenCallCredentials;
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
@@ -14,7 +16,8 @@ public record RegistrationServiceConfiguration(@NotBlank String host,
int port,
@NotBlank String credentialConfigurationJson,
@NotBlank String identityTokenAudience,
@NotBlank String registrationCaCertificate) implements
@NotBlank String registrationCaCertificate,
@NotNull SecretBytes collationKeySalt) implements
RegistrationServiceClientFactory {
@Override
@@ -26,7 +29,7 @@ public record RegistrationServiceConfiguration(@NotBlank String host,
environment.lifecycle().manage(callCredentials);
return new RegistrationServiceClient(host, port, callCredentials, registrationCaCertificate,
return new RegistrationServiceClient(host, port, callCredentials, registrationCaCertificate, collationKeySalt.value(),
identityRefreshExecutor);
} catch (IOException e) {
throw new RuntimeException(e);

View File

@@ -173,7 +173,8 @@ public class VerificationController {
name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed",
schema = @Schema(implementation = Integer.class)))
public VerificationSessionResponse createSession(@NotNull @Valid final CreateVerificationSessionRequest request)
public VerificationSessionResponse createSession(@NotNull @Valid final CreateVerificationSessionRequest request,
@Context final ContainerRequestContext requestContext)
throws RateLimitExceededException, ObsoletePhoneNumberFormatException {
final Pair<String, PushNotification.TokenType> pushTokenAndType = validateAndExtractPushToken(
@@ -188,7 +189,9 @@ public class VerificationController {
final RegistrationServiceSession registrationServiceSession;
try {
registrationServiceSession = registrationServiceClient.createRegistrationSession(phoneNumber,
final String sourceHost = (String) requestContext.getProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
registrationServiceSession = registrationServiceClient.createRegistrationSession(phoneNumber, sourceHost,
accountsManager.getByE164(request.getNumber()).isPresent(),
REGISTRATION_RPC_TIMEOUT).join();
} catch (final CancellationException e) {

View File

@@ -14,12 +14,15 @@ import io.grpc.TlsChannelCredentials;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.util.Base64;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import javax.crypto.Mac;
import org.apache.commons.lang3.StringUtils;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.signal.registration.rpc.CheckVerificationCodeRequest;
@@ -35,9 +38,12 @@ import org.whispersystems.textsecuregcm.util.CompletableFutureUtil;
public class RegistrationServiceClient implements Managed {
private static final Base64.Encoder BASE64_UNPADDED_ENCODER = Base64.getEncoder().withoutPadding();
private final ManagedChannel channel;
private final RegistrationServiceGrpc.RegistrationServiceFutureStub stub;
private final Executor callbackExecutor;
private final byte[] collationKeySalt;
/**
* @param from an e164 in a {@code long} representation e.g. {@code 18005550123}
@@ -60,6 +66,7 @@ public class RegistrationServiceClient implements Managed {
final int port,
final CallCredentials callCredentials,
final String caCertificatePem,
final byte[] collationKeySalt,
final Executor callbackExecutor) throws IOException {
try (final ByteArrayInputStream certificateInputStream = new ByteArrayInputStream(caCertificatePem.getBytes(StandardCharsets.UTF_8))) {
@@ -73,19 +80,22 @@ public class RegistrationServiceClient implements Managed {
}
this.stub = RegistrationServiceGrpc.newFutureStub(channel).withCallCredentials(callCredentials);
this.collationKeySalt = collationKeySalt;
this.callbackExecutor = callbackExecutor;
}
public CompletableFuture<RegistrationServiceSession> createRegistrationSession(
final Phonenumber.PhoneNumber phoneNumber, final boolean accountExistsWithPhoneNumber, final Duration timeout) {
final Phonenumber.PhoneNumber phoneNumber, final String sourceHost, final boolean accountExistsWithPhoneNumber, final Duration timeout) {
final long e164 = Long.parseLong(
PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164).substring(1));
final String rateLimitCollationKey = hmac(sourceHost, collationKeySalt);
return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout))
.createSession(CreateRegistrationSessionRequest.newBuilder()
.setE164(e164)
.setAccountExistsWithE164(accountExistsWithPhoneNumber)
.setRateLimitCollationKey(rateLimitCollationKey)
.build()), callbackExecutor)
.thenApply(response -> switch (response.getResponseCase()) {
case SESSION_METADATA -> buildSessionResponseFromMetadata(response.getSessionMetadata());
@@ -259,4 +269,18 @@ public class RegistrationServiceClient implements Managed {
channel.shutdown();
}
}
private static String hmac(String sourceHost, byte[] collationKeySalt) {
final Mac hmacSha256;
try {
hmacSha256 = Mac.getInstance("HmacSHA256");
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
hmacSha256.update(sourceHost.getBytes(StandardCharsets.UTF_8));
hmacSha256.update(collationKeySalt);
return BASE64_UNPADDED_ENCODER.encodeToString(hmacSha256.doFinal());
}
}