Handle unexpectedly missing last-resort prekeys

This commit is contained in:
Ravi Khadiwala
2025-07-14 11:43:21 -05:00
committed by ravi-signal
parent bf779f30ab
commit 37d67f110a
2 changed files with 90 additions and 33 deletions

View File

@@ -50,6 +50,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.IdentityKey;
@@ -109,6 +110,7 @@ class KeysControllerTest {
private static final int SAMPLE_REGISTRATION_ID4 = 1555;
private static final int SAMPLE_PNI_REGISTRATION_ID = 1717;
private static final int SAMPLE_PNI_REGISTRATION_ID2 = 1718;
private final ECKeyPair IDENTITY_KEY_PAIR = ECKeyPair.generate();
private final IdentityKey IDENTITY_KEY = new IdentityKey(IDENTITY_KEY_PAIR.getPublicKey());
@@ -162,6 +164,7 @@ class KeysControllerTest {
.build();
private Device sampleDevice;
private Device sampleDevice2;
private record WeaklyTypedPreKey(long keyId,
@@ -199,8 +202,8 @@ class KeysControllerTest {
void setup() {
clock.unpin();
sampleDevice = mock(Device.class);
final Device sampleDevice2 = mock(Device.class);
sampleDevice = mock(Device.class);
sampleDevice2 = mock(Device.class);
final Device sampleDevice3 = mock(Device.class);
final Device sampleDevice4 = mock(Device.class);
@@ -218,6 +221,7 @@ class KeysControllerTest {
when(sampleDevice3.getRegistrationId(IdentityType.ACI)).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice4.getRegistrationId(IdentityType.ACI)).thenReturn(SAMPLE_REGISTRATION_ID4);
when(sampleDevice.getRegistrationId(IdentityType.PNI)).thenReturn(SAMPLE_PNI_REGISTRATION_ID);
when(sampleDevice2.getRegistrationId(IdentityType.PNI)).thenReturn(SAMPLE_PNI_REGISTRATION_ID2);
when(sampleDevice.getId()).thenReturn(sampleDeviceId);
when(sampleDevice2.getId()).thenReturn(sampleDevice2Id);
when(sampleDevice3.getId()).thenReturn(sampleDevice3Id);
@@ -357,30 +361,6 @@ class KeysControllerTest {
verifyNoMoreInteractions(KEYS);
}
@Test
void validSingleRequestPqTestNoPqKeysV2() {
when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.queryParam("pq", "true")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull();
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
@Test
void validSingleRequestPqTestV2() {
PreKeyResponse result = resources.getJerseyTest()
@@ -482,6 +462,59 @@ class KeysControllerTest {
verifyNoMoreInteractions(KEYS);
}
private enum RereadBehavior {
ACCOUNT_MISSING,
DEVICE_MISSING,
REG_ID_CHANGED,
PRESENT
}
@ParameterizedTest
@EnumSource(RereadBehavior.class)
void testGetKeysMissingLastResort(RereadBehavior rereadBehavior) {
when(KEYS.takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID2))
.thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY2)));
when(KEYS.takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY)));
when(KEYS.takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID2))
// Missing PQ key
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
switch (rereadBehavior) {
case ACCOUNT_MISSING -> when(accounts.getByServiceIdentifierAsync(new PniServiceIdentifier(EXISTS_PNI)))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
case DEVICE_MISSING -> when(existsAccount.getDevice(SAMPLE_DEVICE_ID2))
.thenReturn(Optional.empty());
case REG_ID_CHANGED -> when(sampleDevice2.getRegistrationId(IdentityType.PNI))
.thenReturn(SAMPLE_PNI_REGISTRATION_ID2)
.thenReturn(SAMPLE_PNI_REGISTRATION_ID2 + 1);
case PRESENT -> {
}
}
when(existsAccount.getDevices()).thenReturn(List.of(sampleDevice, sampleDevice2));
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/PNI:%s/*", EXISTS_PNI))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get();
if (rereadBehavior == RereadBehavior.PRESENT) {
// The device was missing a last resort prekey which should be impossible
assertThat(response.getStatus()).isEqualTo(500);
} else {
// In the other cases, the device plausibly disappeared so we can just leave that device out
final PreKeyResponse result = response.readEntity(PreKeyResponse.class);
assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.PNI));
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY);
}
}
@ParameterizedTest
@MethodSource
void testGetKeysWithGroupSendEndorsement(
@@ -675,6 +708,8 @@ class KeysControllerTest {
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY2)));
when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID3)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY3)));
when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID4)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY4)));
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID))
@@ -717,7 +752,7 @@ class KeysControllerTest {
deviceId = results.getDevice(SAMPLE_DEVICE_ID4).getDeviceId();
assertEquals(SAMPLE_KEY4, preKey);
assertThat(pqPreKey).isNull();
assertThat(pqPreKey).isEqualTo(SAMPLE_PQ_KEY4);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4);
assertThat(signedPreKey).isNull();
assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID4);