Add collation key to registration service session creation rpc call

This commit is contained in:
Chris Eager
2025-01-22 12:01:12 -06:00
committed by Chris Eager
parent 5cc76f48aa
commit 47550d48e7
10 changed files with 67 additions and 17 deletions

View File

@@ -23,6 +23,7 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession;
import org.whispersystems.textsecuregcm.registration.ClientType;
import org.whispersystems.textsecuregcm.registration.MessageTransport;
@@ -35,12 +36,16 @@ public class StubRegistrationServiceClientFactory implements RegistrationService
@NotNull
private String registrationCaCertificate;
@JsonProperty
@NotNull
private SecretBytes collationKeySalt;
@Override
public RegistrationServiceClient build(final Environment environment, final Executor callbackExecutor,
final ScheduledExecutorService identityRefreshExecutor) {
try {
return new StubRegistrationServiceClient(registrationCaCertificate);
return new StubRegistrationServiceClient(registrationCaCertificate, collationKeySalt.value());
} catch (IOException e) {
throw new RuntimeException(e);
}
@@ -50,13 +55,13 @@ public class StubRegistrationServiceClientFactory implements RegistrationService
private final static Map<String, RegistrationServiceSession> SESSIONS = new ConcurrentHashMap<>();
public StubRegistrationServiceClient(final String registrationCaCertificate) throws IOException {
super("example.com", 8080, null, registrationCaCertificate, null);
public StubRegistrationServiceClient(final String registrationCaCertificate, final byte[] collationKeySalt) throws IOException {
super("example.com", 8080, null, registrationCaCertificate, collationKeySalt, null);
}
@Override
public CompletableFuture<RegistrationServiceSession> createRegistrationSession(
final Phonenumber.PhoneNumber phoneNumber, final boolean accountExistsWithPhoneNumber, final Duration timeout) {
final Phonenumber.PhoneNumber phoneNumber, final String sourceHost, final boolean accountExistsWithPhoneNumber, final Duration timeout) {
final String e164 = PhoneNumberUtil.getInstance()
.format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164);

View File

@@ -84,6 +84,7 @@ import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.storage.VerificationSessionManager;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
@ExtendWith(DropwizardExtensionsSupport.class)
class VerificationControllerTest {
@@ -120,6 +121,7 @@ class VerificationControllerTest {
.addProvider(new NonNormalizedPhoneNumberExceptionMapper())
.addProvider(new ObsoletePhoneNumberFormatExceptionMapper())
.addProvider(new RegistrationServiceSenderExceptionMapper())
.addProvider(new TestRemoteAddressFilterProvider("127.0.0.1"))
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(
@@ -190,7 +192,7 @@ class VerificationControllerTest {
@Test
void createSessionRateLimited() {
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyString(), anyBoolean(), any()))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null)));
final Invocation.Builder request = resources.getJerseyTest()
@@ -204,7 +206,7 @@ class VerificationControllerTest {
@Test
void createSessionRegistrationServiceError() {
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyString(), anyBoolean(), any()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException("expected service error")));
final Invocation.Builder request = resources.getJerseyTest()
@@ -219,7 +221,7 @@ class VerificationControllerTest {
@ParameterizedTest
@MethodSource
void createBeninSessionSuccess(final String requestedNumber, final String expectedNumber) {
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyString(), anyBoolean(), any()))
.thenReturn(
CompletableFuture.completedFuture(
new RegistrationServiceSession(SESSION_ID, requestedNumber, false, null, null, null,
@@ -236,7 +238,7 @@ class VerificationControllerTest {
final ArgumentCaptor<Phonenumber.PhoneNumber> phoneNumberArgumentCaptor = ArgumentCaptor.forClass(
Phonenumber.PhoneNumber.class);
verify(registrationServiceClient).createRegistrationSession(phoneNumberArgumentCaptor.capture(), anyBoolean(), any());
verify(registrationServiceClient).createRegistrationSession(phoneNumberArgumentCaptor.capture(), anyString(), anyBoolean(), any());
final Phonenumber.PhoneNumber phoneNumber = phoneNumberArgumentCaptor.getValue();
assertEquals(expectedNumber, PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164));
@@ -260,7 +262,7 @@ class VerificationControllerTest {
.format(PhoneNumberUtil.getInstance().getExampleNumber("BJ"), PhoneNumberUtil.PhoneNumberFormat.E164);
final String oldFormatBeninE164 = newFormatBeninE164.replaceFirst("01", "");
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyString(), anyBoolean(), any()))
.thenReturn(
CompletableFuture.completedFuture(
new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null,
@@ -281,7 +283,7 @@ class VerificationControllerTest {
@MethodSource
void createSessionSuccess(final String pushToken, final String pushTokenType,
final List<VerificationSession.Information> expectedRequestedInformation) {
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyString(), anyBoolean(), any()))
.thenReturn(
CompletableFuture.completedFuture(
new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null,
@@ -315,7 +317,7 @@ class VerificationControllerTest {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void createSessionReregistration(final boolean isReregistration) throws NumberParseException {
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyString(), anyBoolean(), any()))
.thenReturn(
CompletableFuture.completedFuture(
new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null,
@@ -337,6 +339,7 @@ class VerificationControllerTest {
verify(registrationServiceClient).createRegistrationSession(
eq(PhoneNumberUtil.getInstance().parse(NUMBER, null)),
anyString(),
eq(isReregistration),
any()
);