Store and retrieve one-time pre-keys by UUID

This commit is contained in:
Jon Chambers
2021-11-03 18:21:40 -04:00
committed by Jon Chambers
parent 5e1334e8de
commit 3a4c5a2bfb
4 changed files with 87 additions and 126 deletions

View File

@@ -5,135 +5,99 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.whispersystems.textsecuregcm.entities.PreKey;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class KeysTest {
private Account account;
private Keys keys;
@ClassRule
public static KeysDynamoDbRule dynamoDbRule = new KeysDynamoDbRule();
private static final String ACCOUNT_NUMBER = "+18005551234";
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final long DEVICE_ID = 1L;
@Before
public void setup() {
keys = new Keys(dynamoDbRule.getDynamoDbClient(), KeysDynamoDbRule.TABLE_NAME);
account = mock(Account.class);
when(account.getNumber()).thenReturn(ACCOUNT_NUMBER);
when(account.getUuid()).thenReturn(UUID.randomUUID());
}
@Test
public void testStore() {
assertEquals("Initial pre-key count for an account should be zero",
0, keys.getCount(account, DEVICE_ID));
0, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key")));
assertEquals(1, keys.getCount(account, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key")));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key")));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key")));
assertEquals("Repeatedly storing same key should have no effect",
1, keys.getCount(account, DEVICE_ID));
1, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(account, DEVICE_ID, List.of(new PreKey(2, "different-public-key")));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(2, "different-public-key")));
assertEquals("Inserting a new key should overwrite all prior keys for the given account/device",
1, keys.getCount(account, DEVICE_ID));
1, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(account, DEVICE_ID, List.of(new PreKey(3, "third-public-key"), new PreKey(4, "fourth-public-key")));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(3, "third-public-key"), new PreKey(4, "fourth-public-key")));
assertEquals("Inserting multiple new keys should overwrite all prior keys for the given account/device",
2, keys.getCount(account, DEVICE_ID));
}
@Test
public void testTakeAccount() {
final Device firstDevice = mock(Device.class);
final Device secondDevice = mock(Device.class);
when(firstDevice.getId()).thenReturn(DEVICE_ID);
when(secondDevice.getId()).thenReturn(DEVICE_ID + 1);
when(account.getDevices()).thenReturn(Set.of(firstDevice, secondDevice));
assertEquals(Collections.emptyMap(), keys.take(account));
final PreKey firstDevicePreKey = new PreKey(1, "public-key");
final PreKey secondDevicePreKey = new PreKey(2, "second-key");
keys.store(account, DEVICE_ID, List.of(firstDevicePreKey));
keys.store(account, DEVICE_ID + 1, List.of(secondDevicePreKey));
final Map<Long, PreKey> expectedKeys = Map.of(DEVICE_ID, firstDevicePreKey,
DEVICE_ID + 1, secondDevicePreKey);
assertEquals(expectedKeys, keys.take(account));
assertEquals(0, keys.getCount(account, DEVICE_ID));
assertEquals(0, keys.getCount(account, DEVICE_ID + 1));
2, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
public void testTakeAccountAndDeviceId() {
assertEquals(Optional.empty(), keys.take(account, DEVICE_ID));
assertEquals(Optional.empty(), keys.take(ACCOUNT_UUID, DEVICE_ID));
final PreKey preKey = new PreKey(1, "public-key");
keys.store(account, DEVICE_ID, List.of(preKey, new PreKey(2, "different-pre-key")));
assertEquals(Optional.of(preKey), keys.take(account, DEVICE_ID));
assertEquals(1, keys.getCount(account, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, new PreKey(2, "different-pre-key")));
assertEquals(Optional.of(preKey), keys.take(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
public void testGetCount() {
assertEquals(0, keys.getCount(account, DEVICE_ID));
assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key")));
assertEquals(1, keys.getCount(account, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key")));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
public void testDeleteByAccount() {
keys.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key")));
keys.store(account, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device")));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key")));
keys.store(ACCOUNT_UUID, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device")));
assertEquals(2, keys.getCount(account, DEVICE_ID));
assertEquals(1, keys.getCount(account, DEVICE_ID + 1));
assertEquals(2, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1));
keys.delete(account.getUuid());
keys.delete(ACCOUNT_UUID);
assertEquals(0, keys.getCount(account, DEVICE_ID));
assertEquals(0, keys.getCount(account, DEVICE_ID + 1));
assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1));
}
@Test
public void testDeleteByAccountAndDevice() {
keys.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key")));
keys.store(account, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device")));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key")));
keys.store(ACCOUNT_UUID, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device")));
assertEquals(2, keys.getCount(account, DEVICE_ID));
assertEquals(1, keys.getCount(account, DEVICE_ID + 1));
assertEquals(2, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1));
keys.delete(account.getUuid(), DEVICE_ID);
keys.delete(ACCOUNT_UUID, DEVICE_ID);
assertEquals(0, keys.getCount(account, DEVICE_ID));
assertEquals(1, keys.getCount(account, DEVICE_ID + 1));
assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1));
}
@Test

View File

@@ -15,7 +15,6 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.tests.util.AccountsHelper.eqUuid;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
@@ -26,7 +25,6 @@ import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
@@ -142,6 +140,7 @@ class KeysControllerTest {
when(sampleDevice3.getId()).thenReturn(3L);
when(sampleDevice4.getId()).thenReturn(4L);
when(existsAccount.getUuid()).thenReturn(EXISTS_UUID);
when(existsAccount.getDevice(1L)).thenReturn(Optional.of(sampleDevice));
when(existsAccount.getDevice(2L)).thenReturn(Optional.of(sampleDevice2));
when(existsAccount.getDevice(3L)).thenReturn(Optional.of(sampleDevice3));
@@ -161,14 +160,9 @@ class KeysControllerTest {
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
when(KEYS.take(eq(existsAccount), eq(1L))).thenReturn(Optional.of(SAMPLE_KEY));
when(KEYS.take(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY));
when(KEYS.take(existsAccount)).thenReturn(Map.of(1L, SAMPLE_KEY,
2L, SAMPLE_KEY2,
3L, SAMPLE_KEY3,
4L, SAMPLE_KEY4));
when(KEYS.getCount(eq(AuthHelper.VALID_ACCOUNT), eq(1L))).thenReturn(5);
when(KEYS.getCount(AuthHelper.VALID_UUID, 1)).thenReturn(5);
when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY);
when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null);
@@ -198,7 +192,7 @@ class KeysControllerTest {
assertThat(result.getCount()).isEqualTo(4);
verify(KEYS).getCount(eq(AuthHelper.VALID_ACCOUNT), eq(1L));
verify(KEYS).getCount(AuthHelper.VALID_UUID, 1);
}
@@ -257,7 +251,7 @@ class KeysControllerTest {
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey());
verify(KEYS).take(eq(existsAccount), eq(1L));
verify(KEYS).take(EXISTS_UUID, 1);
verifyNoMoreInteractions(KEYS);
}
@@ -275,7 +269,7 @@ class KeysControllerTest {
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey());
verify(KEYS).take(eq(existsAccount), eq(1L));
verify(KEYS).take(EXISTS_UUID, 1);
verifyNoMoreInteractions(KEYS);
}
@@ -321,8 +315,13 @@ class KeysControllerTest {
@Test
void validMultiRequestTestV2() {
when(KEYS.take(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY));
when(KEYS.take(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_KEY2));
when(KEYS.take(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3));
when(KEYS.take(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4));
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID.toString()))
.target(String.format("/v2/keys/%s/*", EXISTS_UUID))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
@@ -332,8 +331,8 @@ class KeysControllerTest {
PreKey signedPreKey = results.getDevice(1).getSignedPreKey();
PreKey preKey = results.getDevice(1).getPreKey();
long registrationId = results.getDevice(1).getRegistrationId();
long deviceId = results.getDevice(1).getDeviceId();
long registrationId = results.getDevice(1).getRegistrationId();
long deviceId = results.getDevice(1).getDeviceId();
assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
@@ -365,7 +364,10 @@ class KeysControllerTest {
assertThat(signedPreKey).isNull();
assertThat(deviceId).isEqualTo(4);
verify(KEYS).take(eq(existsAccount));
verify(KEYS).take(EXISTS_UUID, 1);
verify(KEYS).take(EXISTS_UUID, 2);
verify(KEYS).take(EXISTS_UUID, 3);
verify(KEYS).take(EXISTS_UUID, 4);
verifyNoMoreInteractions(KEYS);
}
@@ -433,8 +435,8 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eqUuid(AuthHelper.VALID_ACCOUNT), eq(1L), listCaptor.capture());
ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), listCaptor.capture());
List<PreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1);
@@ -467,8 +469,8 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eqUuid(AuthHelper.DISABLED_ACCOUNT), eq(1L), listCaptor.capture());
ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(1L), listCaptor.capture());
List<PreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1);