Migrate username storage from a relational database to DynamoDB

This commit is contained in:
Jon Chambers
2021-12-01 16:50:18 -05:00
committed by GitHub
parent 0d4a3b1ad4
commit d94e86781f
23 changed files with 664 additions and 764 deletions

View File

@@ -90,7 +90,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.storage.UsernamesManager;
import org.whispersystems.textsecuregcm.storage.UsernameNotAvailableException;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.Hex;
@@ -141,7 +141,6 @@ class AccountControllerTest {
private static RecaptchaClient recaptchaClient = mock(RecaptchaClient.class);
private static GCMSender gcmSender = mock(GCMSender.class);
private static APNSender apnSender = mock(APNSender.class);
private static UsernamesManager usernamesManager = mock(UsernamesManager.class);
private static DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
@@ -164,7 +163,6 @@ class AccountControllerTest {
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new AccountController(pendingAccountsManager,
accountsManager,
usernamesManager,
abusiveHostRules,
rateLimiters,
smsSender,
@@ -242,8 +240,8 @@ class AccountControllerTest {
return account;
});
when(usernamesManager.put(eq(AuthHelper.VALID_UUID), eq("n00bkiller"))).thenReturn(true);
when(usernamesManager.put(eq(AuthHelper.VALID_UUID), eq("takenusername"))).thenReturn(false);
when(accountsManager.setUsername(AuthHelper.VALID_ACCOUNT, "takenusername"))
.thenThrow(new UsernameNotAvailableException());
{
DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
@@ -293,7 +291,6 @@ class AccountControllerTest {
recaptchaClient,
gcmSender,
apnSender,
usernamesManager,
verifyExperimentEnrollmentManager);
clearInvocations(AuthHelper.DISABLED_DEVICE);
@@ -1622,7 +1619,7 @@ class AccountControllerTest {
.delete();
assertThat(response.getStatus()).isEqualTo(204);
verify(usernamesManager, times(1)).delete(eq(AuthHelper.VALID_UUID));
verify(accountsManager).clearUsername(AuthHelper.VALID_ACCOUNT);
}
@Test

View File

@@ -66,7 +66,6 @@ import org.whispersystems.textsecuregcm.storage.AccountBadge;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.UsernamesManager;
import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@@ -80,7 +79,6 @@ class ProfileControllerTest {
private static final Clock clock = mock(Clock.class);
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final ProfilesManager profilesManager = mock(ProfilesManager.class);
private static final UsernamesManager usernamesManager = mock(UsernamesManager.class);
private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter rateLimiter = mock(RateLimiter.class);
private static final RateLimiter usernameRateLimiter = mock(RateLimiter.class);
@@ -107,7 +105,6 @@ class ProfileControllerTest {
rateLimiters,
accountsManager,
profilesManager,
usernamesManager,
dynamicConfigurationManager,
(acceptableLanguages, accountBadges, isSelf) -> List.of(new Badge("TEST", "other", "Test Badge",
"This badge is in unit tests.", List.of("l", "m", "h", "x", "xx", "xxx"), "SVG", List.of(new BadgeSvg("sl", "sd"), new BadgeSvg("ml", "md"), new BadgeSvg("ll", "ld")))
@@ -156,6 +153,7 @@ class ProfileControllerTest {
when(profileAccount.isAnnouncementGroupSupported()).thenReturn(false);
when(profileAccount.isChangeNumberSupported()).thenReturn(false);
when(profileAccount.getCurrentProfileVersion()).thenReturn(Optional.empty());
when(profileAccount.getUsername()).thenReturn(Optional.of("n00bkiller"));
Account capabilitiesAccount = mock(Account.class);
@@ -171,8 +169,7 @@ class ProfileControllerTest {
when(accountsManager.getByE164(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(profileAccount));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(profileAccount));
when(usernamesManager.get(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of("n00bkiller"));
when(usernamesManager.get("n00bkiller")).thenReturn(Optional.of(AuthHelper.VALID_UUID_TWO));
when(accountsManager.getByUsername("n00bkiller")).thenReturn(Optional.of(profileAccount));
when(accountsManager.getByE164(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(capabilitiesAccount));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(capabilitiesAccount));
@@ -183,7 +180,6 @@ class ProfileControllerTest {
clearInvocations(rateLimiter);
clearInvocations(accountsManager);
clearInvocations(usernamesManager);
clearInvocations(usernameRateLimiter);
clearInvocations(profilesManager);
}
@@ -209,7 +205,6 @@ class ProfileControllerTest {
badge -> "Test Badge".equals(badge.getName()), "has badge with expected name"));
verify(accountsManager).getByAccountIdentifier(AuthHelper.VALID_UUID_TWO);
verify(usernamesManager, times(1)).get(eq(AuthHelper.VALID_UUID_TWO));
verify(rateLimiter, times(1)).validate(AuthHelper.VALID_UUID);
}
@@ -229,8 +224,7 @@ class ProfileControllerTest {
assertThat(profile.getBadges()).hasSize(1).element(0).has(new Condition<>(
badge -> "Test Badge".equals(badge.getName()), "has badge with expected name"));
verify(accountsManager, times(1)).getByAccountIdentifier(eq(AuthHelper.VALID_UUID_TWO));
verify(usernamesManager, times(1)).get(eq("n00bkiller"));
verify(accountsManager).getByUsername("n00bkiller");
verify(usernameRateLimiter, times(1)).validate(eq(AuthHelper.VALID_UUID));
}
@@ -265,8 +259,8 @@ class ProfileControllerTest {
assertThat(response.getStatus()).isEqualTo(404);
verify(usernamesManager, times(1)).get(eq("n00bkillerzzzzz"));
verify(usernameRateLimiter, times(1)).validate(eq(AuthHelper.VALID_UUID));
verify(accountsManager).getByUsername("n00bkillerzzzzz");
verify(usernameRateLimiter).validate(eq(AuthHelper.VALID_UUID));
}
@@ -594,7 +588,6 @@ class ProfileControllerTest {
badge -> "Test Badge".equals(badge.getName()), "has badge with expected name"));
verify(accountsManager, times(1)).getByAccountIdentifier(eq(AuthHelper.VALID_UUID_TWO));
verify(usernamesManager, times(1)).get(eq(AuthHelper.VALID_UUID_TWO));
verify(profilesManager, times(1)).get(eq(AuthHelper.VALID_UUID_TWO), eq("validversion"));
verify(rateLimiter, times(1)).validate(AuthHelper.VALID_UUID);

View File

@@ -1,630 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import io.lettuce.core.RedisException;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.time.Clock;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ContestedOptimisticLockException;
import org.whispersystems.textsecuregcm.storage.DeletedAccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.storage.UsernamesManager;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
class AccountsManagerTest {
private Accounts accounts;
private DeletedAccountsManager deletedAccountsManager;
private DirectoryQueue directoryQueue;
private Keys keys;
private MessagesManager messagesManager;
private ProfilesManager profilesManager;
private RedisAdvancedClusterCommands<String, String> commands;
private AccountsManager accountsManager;
private static final Answer<?> ACCOUNT_UPDATE_ANSWER = (answer) -> {
// it is implicit in the update() contract is that a successful call will
// result in an incremented version
final Account updatedAccount = answer.getArgument(0, Account.class);
updatedAccount.setVersion(updatedAccount.getVersion() + 1);
return null;
};
@BeforeEach
void setup() throws InterruptedException {
accounts = mock(Accounts.class);
deletedAccountsManager = mock(DeletedAccountsManager.class);
directoryQueue = mock(DirectoryQueue.class);
keys = mock(Keys.class);
messagesManager = mock(MessagesManager.class);
profilesManager = mock(ProfilesManager.class);
//noinspection unchecked
commands = mock(RedisAdvancedClusterCommands.class);
doAnswer((Answer<Void>) invocation -> {
final Account account = invocation.getArgument(0, Account.class);
final String number = invocation.getArgument(1, String.class);
final UUID phoneNumberIdentifier = invocation.getArgument(2, UUID.class);
account.setNumber(number, phoneNumberIdentifier);
return null;
}).when(accounts).changeNumber(any(), anyString(), any());
doAnswer(invocation -> {
//noinspection unchecked
invocation.getArgument(1, Consumer.class).accept(Optional.empty());
return null;
}).when(deletedAccountsManager).lockAndTake(anyString(), any());
final SecureStorageClient storageClient = mock(SecureStorageClient.class);
when(storageClient.deleteStoredData(any())).thenReturn(CompletableFuture.completedFuture(null));
final SecureBackupClient backupClient = mock(SecureBackupClient.class);
when(backupClient.deleteBackups(any())).thenReturn(CompletableFuture.completedFuture(null));
final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class);
final Map<String, UUID> phoneNumberIdentifiersByE164 = new HashMap<>();
when(phoneNumberIdentifiers.getPhoneNumberIdentifier(anyString())).thenAnswer((Answer<UUID>) invocation -> {
final String number = invocation.getArgument(0, String.class);
return phoneNumberIdentifiersByE164.computeIfAbsent(number, n -> UUID.randomUUID());
});
accountsManager = new AccountsManager(
accounts,
phoneNumberIdentifiers,
RedisClusterHelper.buildMockRedisCluster(commands),
deletedAccountsManager,
directoryQueue,
keys,
messagesManager,
mock(UsernamesManager.class),
profilesManager,
mock(StoredVerificationCodeManager.class),
storageClient,
backupClient,
mock(ClientPresenceManager.class),
mock(Clock.class));
}
@Test
void testGetAccountByNumberInCache() {
UUID uuid = UUID.randomUUID();
when(commands.get(eq("AccountMap::+14152222222"))).thenReturn(uuid.toString());
when(commands.get(eq("Account3::" + uuid))).thenReturn("{\"number\": \"+14152222222\", \"name\": \"test\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}");
Optional<Account> account = accountsManager.getByE164("+14152222222");
assertTrue(account.isPresent());
assertEquals(account.get().getNumber(), "+14152222222");
assertEquals(account.get().getProfileName(), "test");
assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier());
verify(commands, times(1)).get(eq("AccountMap::+14152222222"));
verify(commands, times(1)).get(eq("Account3::" + uuid));
verifyNoMoreInteractions(commands);
verifyNoInteractions(accounts);
}
@Test
void testGetAccountByUuidInCache() {
UUID uuid = UUID.randomUUID();
when(commands.get(eq("Account3::" + uuid))).thenReturn("{\"number\": \"+14152222222\", \"name\": \"test\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}");
Optional<Account> account = accountsManager.getByAccountIdentifier(uuid);
assertTrue(account.isPresent());
assertEquals(account.get().getNumber(), "+14152222222");
assertEquals(account.get().getUuid(), uuid);
assertEquals(account.get().getProfileName(), "test");
assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier());
verify(commands, times(1)).get(eq("Account3::" + uuid));
verifyNoMoreInteractions(commands);
verifyNoInteractions(accounts);
}
@Test
void testGetByPniInCache() {
UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID();
when(commands.get(eq("AccountMap::" + pni))).thenReturn(uuid.toString());
when(commands.get(eq("Account3::" + uuid))).thenReturn("{\"number\": \"+14152222222\", \"name\": \"test\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}");
Optional<Account> account = accountsManager.getByPhoneNumberIdentifier(pni);
assertTrue(account.isPresent());
assertEquals(account.get().getNumber(), "+14152222222");
assertEquals(account.get().getProfileName(), "test");
assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier());
verify(commands).get(eq("AccountMap::" + pni));
verify(commands).get(eq("Account3::" + uuid));
verifyNoMoreInteractions(commands);
verifyNoInteractions(accounts);
}
@Test
void testGetAccountByNumberNotInCache() {
UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16]);
when(commands.get(eq("AccountMap::+14152222222"))).thenReturn(null);
when(accounts.getByE164(eq("+14152222222"))).thenReturn(Optional.of(account));
Optional<Account> retrieved = accountsManager.getByE164("+14152222222");
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(commands, times(1)).get(eq("AccountMap::+14152222222"));
verify(commands, times(1)).setex(eq("AccountMap::+14152222222"), anyLong(), eq(uuid.toString()));
verify(commands, times(1)).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(commands, times(1)).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(commands);
verify(accounts, times(1)).getByE164(eq("+14152222222"));
verifyNoMoreInteractions(accounts);
}
@Test
void testGetAccountByUuidNotInCache() {
UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16]);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(eq(uuid))).thenReturn(Optional.of(account));
Optional<Account> retrieved = accountsManager.getByAccountIdentifier(uuid);
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(commands, times(1)).get(eq("Account3::" + uuid));
verify(commands, times(1)).setex(eq("AccountMap::+14152222222"), anyLong(), eq(uuid.toString()));
verify(commands, times(1)).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(commands, times(1)).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(commands);
verify(accounts, times(1)).getByAccountIdentifier(eq(uuid));
verifyNoMoreInteractions(accounts);
}
@Test
void testGetAccountByPniNotInCache() {
UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16]);
when(commands.get(eq("AccountMap::" + pni))).thenReturn(null);
when(accounts.getByPhoneNumberIdentifier(pni)).thenReturn(Optional.of(account));
Optional<Account> retrieved = accountsManager.getByPhoneNumberIdentifier(pni);
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(commands).get(eq("AccountMap::" + pni));
verify(commands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(commands).setex(eq("AccountMap::+14152222222"), anyLong(), eq(uuid.toString()));
verify(commands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(commands);
verify(accounts).getByPhoneNumberIdentifier(pni);
verifyNoMoreInteractions(accounts);
}
@Test
void testGetAccountByNumberBrokenCache() {
UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16]);
when(commands.get(eq("AccountMap::+14152222222"))).thenThrow(new RedisException("Connection lost!"));
when(accounts.getByE164(eq("+14152222222"))).thenReturn(Optional.of(account));
Optional<Account> retrieved = accountsManager.getByE164("+14152222222");
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(commands, times(1)).get(eq("AccountMap::+14152222222"));
verify(commands, times(1)).setex(eq("AccountMap::+14152222222"), anyLong(), eq(uuid.toString()));
verify(commands, times(1)).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(commands, times(1)).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(commands);
verify(accounts, times(1)).getByE164(eq("+14152222222"));
verifyNoMoreInteractions(accounts);
}
@Test
void testGetAccountByUuidBrokenCache() {
UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16]);
when(commands.get(eq("Account3::" + uuid))).thenThrow(new RedisException("Connection lost!"));
when(accounts.getByAccountIdentifier(eq(uuid))).thenReturn(Optional.of(account));
Optional<Account> retrieved = accountsManager.getByAccountIdentifier(uuid);
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(commands, times(1)).get(eq("Account3::" + uuid));
verify(commands, times(1)).setex(eq("AccountMap::+14152222222"), anyLong(), eq(uuid.toString()));
verify(commands, times(1)).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(commands, times(1)).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(commands);
verify(accounts, times(1)).getByAccountIdentifier(eq(uuid));
verifyNoMoreInteractions(accounts);
}
@Test
void testGetAccountByPniBrokenCache() {
UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16]);
when(commands.get(eq("AccountMap::" + pni))).thenThrow(new RedisException("OH NO"));
when(accounts.getByPhoneNumberIdentifier(pni)).thenReturn(Optional.of(account));
Optional<Account> retrieved = accountsManager.getByPhoneNumberIdentifier(pni);
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(commands).get(eq("AccountMap::" + pni));
verify(commands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(commands).setex(eq("AccountMap::+14152222222"), anyLong(), eq(uuid.toString()));
verify(commands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(commands);
verify(accounts).getByPhoneNumberIdentifier(pni);
verifyNoMoreInteractions(accounts);
}
@Test
void testUpdate_optimisticLockingFailure() {
UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16]);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(uuid)).thenReturn(
Optional.of(new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16])));
doThrow(ContestedOptimisticLockException.class)
.doAnswer(ACCOUNT_UPDATE_ANSWER)
.when(accounts).update(any());
when(accounts.getByAccountIdentifier(uuid)).thenReturn(
Optional.of(new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16])));
doThrow(ContestedOptimisticLockException.class)
.doAnswer(ACCOUNT_UPDATE_ANSWER)
.when(accounts).update(any());
account = accountsManager.update(account, a -> a.setProfileName("name"));
assertEquals(1, account.getVersion());
assertEquals("name", account.getProfileName());
verify(accounts, times(1)).getByAccountIdentifier(uuid);
verify(accounts, times(2)).update(any());
verifyNoMoreInteractions(accounts);
}
@Test
void testUpdate_dynamoOptimisticLockingFailureDuringCreate() {
UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, UUID.randomUUID(), new HashSet<>(), new byte[16]);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(uuid)).thenReturn(Optional.empty())
.thenReturn(Optional.of(account));
when(accounts.create(any())).thenThrow(ContestedOptimisticLockException.class);
accountsManager.update(account, a -> {
});
verify(accounts, times(1)).update(account);
verifyNoMoreInteractions(accounts);
}
@Test
void testUpdateDevice() {
final UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, UUID.randomUUID(), new HashSet<>(), new byte[16]);
when(accounts.getByAccountIdentifier(uuid)).thenReturn(
Optional.of(new Account("+14152222222", uuid, UUID.randomUUID(), new HashSet<>(), new byte[16])));
assertTrue(account.getDevices().isEmpty());
Device enabledDevice = new Device();
enabledDevice.setFetchesMessages(true);
enabledDevice.setSignedPreKey(new SignedPreKey(1L, "key", "signature"));
enabledDevice.setLastSeen(System.currentTimeMillis());
final long deviceId = account.getNextDeviceId();
enabledDevice.setId(deviceId);
account.addDevice(enabledDevice);
@SuppressWarnings("unchecked") Consumer<Device> deviceUpdater = mock(Consumer.class);
@SuppressWarnings("unchecked") Consumer<Device> unknownDeviceUpdater = mock(Consumer.class);
account = accountsManager.updateDevice(account, deviceId, deviceUpdater);
account = accountsManager.updateDevice(account, deviceId, d -> d.setName("deviceName"));
assertEquals("deviceName", account.getDevice(deviceId).orElseThrow().getName());
verify(deviceUpdater, times(1)).accept(any(Device.class));
accountsManager.updateDevice(account, account.getNextDeviceId(), unknownDeviceUpdater);
verify(unknownDeviceUpdater, never()).accept(any(Device.class));
}
@Test
void testCreateFreshAccount() throws InterruptedException {
when(accounts.create(any())).thenReturn(true);
final String e164 = "+18005550123";
final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, null);
accountsManager.create(e164, "password", null, attributes, new ArrayList<>());
verify(accounts).create(argThat(account -> e164.equals(account.getNumber())));
verifyNoInteractions(keys);
verifyNoInteractions(messagesManager);
verifyNoInteractions(profilesManager);
}
@Test
void testReregisterAccount() throws InterruptedException {
final UUID existingUuid = UUID.randomUUID();
when(accounts.create(any())).thenAnswer(invocation -> {
invocation.getArgument(0, Account.class).setUuid(existingUuid);
return false;
});
final String e164 = "+18005550123";
final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, null);
accountsManager.create(e164, "password", null, attributes, new ArrayList<>());
verify(accounts).create(
argThat(account -> e164.equals(account.getNumber()) && existingUuid.equals(account.getUuid())));
verify(keys).delete(existingUuid);
verify(messagesManager).clear(existingUuid);
verify(profilesManager).deleteAll(existingUuid);
}
@Test
void testCreateAccountRecentlyDeleted() throws InterruptedException {
final UUID recentlyDeletedUuid = UUID.randomUUID();
doAnswer(invocation -> {
//noinspection unchecked
invocation.getArgument(1, Consumer.class).accept(Optional.of(recentlyDeletedUuid));
return null;
}).when(deletedAccountsManager).lockAndTake(anyString(), any());
when(accounts.create(any())).thenReturn(true);
final String e164 = "+18005550123";
final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, null);
accountsManager.create(e164, "password", null, attributes, new ArrayList<>());
verify(accounts).create(
argThat(account -> e164.equals(account.getNumber()) && recentlyDeletedUuid.equals(account.getUuid())));
verifyNoInteractions(keys);
verifyNoInteractions(messagesManager);
verifyNoInteractions(profilesManager);
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testCreateWithDiscoverability(final boolean discoverable) throws InterruptedException {
final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, discoverable, null);
final Account account = accountsManager.create("+18005550123", "password", null, attributes, new ArrayList<>());
assertEquals(discoverable, account.isDiscoverableByPhoneNumber());
if (!discoverable) {
verify(directoryQueue).deleteAccount(account);
}
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testCreateWithStorageCapability(final boolean hasStorage) throws InterruptedException {
final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true,
new DeviceCapabilities(false, false, false, hasStorage, false, false, false, false, false));
final Account account = accountsManager.create("+18005550123", "password", null, attributes, new ArrayList<>());
assertEquals(hasStorage, account.isStorageSupported());
}
@ParameterizedTest
@MethodSource
void testUpdateDirectoryQueue(final boolean visibleBeforeUpdate, final boolean visibleAfterUpdate,
final boolean expectRefresh) {
final Account account = new Account("+14152222222", UUID.randomUUID(), UUID.randomUUID(), new HashSet<>(), new byte[16]);
// this sets up the appropriate result for Account#shouldBeVisibleInDirectory
final Device device = new Device(Device.MASTER_ID, "device", "token", "salt", null, null, null, true, 1,
new SignedPreKey(1, "key", "sig"), 0, 0,
"OWT", 0, new DeviceCapabilities());
account.addDevice(device);
account.setDiscoverableByPhoneNumber(visibleBeforeUpdate);
final Account updatedAccount = accountsManager.update(account,
a -> a.setDiscoverableByPhoneNumber(visibleAfterUpdate));
verify(directoryQueue, times(expectRefresh ? 1 : 0)).refreshAccount(updatedAccount);
}
@SuppressWarnings("unused")
private static Stream<Arguments> testUpdateDirectoryQueue() {
return Stream.of(
Arguments.of(false, false, false),
Arguments.of(true, true, false),
Arguments.of(false, true, true),
Arguments.of(true, false, true));
}
@ParameterizedTest
@MethodSource
void testUpdateDeviceLastSeen(final boolean expectUpdate, final long initialLastSeen, final long updatedLastSeen) {
final Account account = new Account("+14152222222", UUID.randomUUID(), UUID.randomUUID(), new HashSet<>(), new byte[16]);
final Device device = new Device(Device.MASTER_ID, "device", "token", "salt", null, null, null, true, 1,
new SignedPreKey(1, "key", "sig"), initialLastSeen, 0,
"OWT", 0, new DeviceCapabilities());
account.addDevice(device);
accountsManager.updateDeviceLastSeen(account, device, updatedLastSeen);
assertEquals(expectUpdate ? updatedLastSeen : initialLastSeen, device.getLastSeen());
verify(accounts, expectUpdate ? times(1) : never()).update(account);
}
@SuppressWarnings("unused")
private static Stream<Arguments> testUpdateDeviceLastSeen() {
return Stream.of(
Arguments.of(true, 1, 2),
Arguments.of(false, 1, 1),
Arguments.of(false, 2, 1)
);
}
@Test
void testChangePhoneNumber() throws InterruptedException {
doAnswer(invocation -> invocation.getArgument(2, Supplier.class).get())
.when(deletedAccountsManager).lockAndPut(anyString(), anyString(), any());
final String originalNumber = "+14152222222";
final String targetNumber = "+14153333333";
final UUID uuid = UUID.randomUUID();
Account account = new Account(originalNumber, uuid, UUID.randomUUID(), new HashSet<>(), new byte[16]);
account = accountsManager.changeNumber(account, targetNumber);
assertEquals(targetNumber, account.getNumber());
verify(directoryQueue).changePhoneNumber(argThat(a -> a.getUuid().equals(uuid)), eq(originalNumber), eq(targetNumber));
}
@Test
void testChangePhoneNumberSameNumber() throws InterruptedException {
final String number = "+14152222222";
Account account = new Account(number, UUID.randomUUID(), UUID.randomUUID(), new HashSet<>(), new byte[16]);
account = accountsManager.changeNumber(account, number);
assertEquals(number, account.getNumber());
verify(deletedAccountsManager, never()).lockAndPut(anyString(), anyString(), any());
verify(directoryQueue, never()).changePhoneNumber(any(), any(), any());
}
@Test
void testChangePhoneNumberExistingAccount() throws InterruptedException {
doAnswer(invocation -> invocation.getArgument(2, Supplier.class).get())
.when(deletedAccountsManager).lockAndPut(anyString(), anyString(), any());
final String originalNumber = "+14152222222";
final String targetNumber = "+14153333333";
final UUID existingAccountUuid = UUID.randomUUID();
final UUID uuid = UUID.randomUUID();
final Account existingAccount = new Account(targetNumber, existingAccountUuid, UUID.randomUUID(), new HashSet<>(), new byte[16]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
Account account = new Account(originalNumber, uuid, UUID.randomUUID(), new HashSet<>(), new byte[16]);
account = accountsManager.changeNumber(account, targetNumber);
assertEquals(targetNumber, account.getNumber());
verify(directoryQueue).changePhoneNumber(argThat(a -> a.getUuid().equals(uuid)), eq(originalNumber), eq(targetNumber));
verify(directoryQueue, never()).deleteAccount(any());
}
@Test
void testChangePhoneNumberViaUpdate() {
final String originalNumber = "+14152222222";
final String targetNumber = "+14153333333";
final UUID uuid = UUID.randomUUID();
final Account account = new Account(originalNumber, uuid, UUID.randomUUID(), new HashSet<>(), new byte[16]);
assertThrows(AssertionError.class, () -> accountsManager.update(account, a -> a.setNumber(targetNumber, UUID.randomUUID())));
}
}

View File

@@ -1,185 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.storage;
import io.lettuce.core.RedisException;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import org.junit.Test;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.ReservedUsernames;
import org.whispersystems.textsecuregcm.storage.Usernames;
import org.whispersystems.textsecuregcm.storage.UsernamesManager;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import java.util.Optional;
import java.util.UUID;
import static junit.framework.TestCase.assertSame;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
public class UsernamesManagerTest {
@Test
public void testGetByUsernameInCache() {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
Usernames usernames = mock(Usernames.class);
ReservedUsernames reserved = mock(ReservedUsernames.class);
UUID uuid = UUID.randomUUID();
when(commands.get(eq("UsernameByUsername::n00bkiller"))).thenReturn(uuid.toString());
UsernamesManager usernamesManager = new UsernamesManager(usernames, reserved, cacheCluster);
Optional<UUID> retrieved = usernamesManager.get("n00bkiller");
assertTrue(retrieved.isPresent());
assertEquals(retrieved.get(), uuid);
verify(commands, times(1)).get(eq("UsernameByUsername::n00bkiller"));
verifyNoMoreInteractions(commands);
verifyNoMoreInteractions(usernames);
}
@Test
public void testGetByUuidInCache() {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
Usernames usernames = mock(Usernames.class);
ReservedUsernames reserved = mock(ReservedUsernames.class);
UUID uuid = UUID.randomUUID();
when(commands.get(eq("UsernameByUuid::" + uuid.toString()))).thenReturn("n00bkiller");
UsernamesManager usernamesManager = new UsernamesManager(usernames, reserved, cacheCluster);
Optional<String> retrieved = usernamesManager.get(uuid);
assertTrue(retrieved.isPresent());
assertEquals(retrieved.get(), "n00bkiller");
verify(commands, times(1)).get(eq("UsernameByUuid::" + uuid.toString()));
verifyNoMoreInteractions(commands);
verifyNoMoreInteractions(usernames);
}
@Test
public void testGetByUsernameNotInCache() {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
Usernames usernames = mock(Usernames.class);
ReservedUsernames reserved = mock(ReservedUsernames.class);
UUID uuid = UUID.randomUUID();
when(commands.get(eq("UsernameByUsername::n00bkiller"))).thenReturn(null);
when(usernames.get(eq("n00bkiller"))).thenReturn(Optional.of(uuid));
UsernamesManager usernamesManager = new UsernamesManager(usernames, reserved, cacheCluster);
Optional<UUID> retrieved = usernamesManager.get("n00bkiller");
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), uuid);
verify(commands, times(1)).get(eq("UsernameByUsername::n00bkiller"));
verify(commands, times(1)).set(eq("UsernameByUsername::n00bkiller"), eq(uuid.toString()));
verify(commands, times(1)).set(eq("UsernameByUuid::" + uuid.toString()), eq("n00bkiller"));
verify(commands, times(1)).get(eq("UsernameByUuid::" + uuid.toString()));
verifyNoMoreInteractions(commands);
verify(usernames, times(1)).get(eq("n00bkiller"));
verifyNoMoreInteractions(usernames);
}
@Test
public void testGetByUuidNotInCache() {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
Usernames usernames = mock(Usernames.class);
ReservedUsernames reserved = mock(ReservedUsernames.class);
UUID uuid = UUID.randomUUID();
when(commands.get(eq("UsernameByUuid::" + uuid.toString()))).thenReturn(null);
when(usernames.get(eq(uuid))).thenReturn(Optional.of("n00bkiller"));
UsernamesManager usernamesManager = new UsernamesManager(usernames, reserved, cacheCluster);
Optional<String> retrieved = usernamesManager.get(uuid);
assertTrue(retrieved.isPresent());
assertEquals(retrieved.get(), "n00bkiller");
verify(commands, times(2)).get(eq("UsernameByUuid::" + uuid));
verify(commands, times(1)).set(eq("UsernameByUuid::" + uuid), eq("n00bkiller"));
verify(commands, times(1)).set(eq("UsernameByUsername::n00bkiller"), eq(uuid.toString()));
verifyNoMoreInteractions(commands);
verify(usernames, times(1)).get(eq(uuid));
verifyNoMoreInteractions(usernames);
}
@Test
public void testGetByUsernameBrokenCache() {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
Usernames usernames = mock(Usernames.class);
ReservedUsernames reserved = mock(ReservedUsernames.class);
UUID uuid = UUID.randomUUID();
when(commands.get(eq("UsernameByUsername::n00bkiller"))).thenThrow(new RedisException("Connection lost!"));
when(usernames.get(eq("n00bkiller"))).thenReturn(Optional.of(uuid));
UsernamesManager usernamesManager = new UsernamesManager(usernames, reserved, cacheCluster);
Optional<UUID> retrieved = usernamesManager.get("n00bkiller");
assertTrue(retrieved.isPresent());
assertEquals(retrieved.get(), uuid);
verify(commands, times(1)).get(eq("UsernameByUsername::n00bkiller"));
verify(commands, times(1)).set(eq("UsernameByUsername::n00bkiller"), eq(uuid.toString()));
verify(commands, times(1)).set(eq("UsernameByUuid::" + uuid.toString()), eq("n00bkiller"));
verify(commands, times(1)).get(eq("UsernameByUuid::" + uuid.toString()));
verifyNoMoreInteractions(commands);
verify(usernames, times(1)).get(eq("n00bkiller"));
verifyNoMoreInteractions(usernames);
}
@Test
public void testGetAccountByUuidBrokenCache() {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
Usernames usernames = mock(Usernames.class);
ReservedUsernames reserved = mock(ReservedUsernames.class);
UUID uuid = UUID.randomUUID();
when(commands.get(eq("UsernameByUuid::" + uuid))).thenThrow(new RedisException("Connection lost!"));
when(usernames.get(eq(uuid))).thenReturn(Optional.of("n00bkiller"));
UsernamesManager usernamesManager = new UsernamesManager(usernames, reserved, cacheCluster);
Optional<String> retrieved = usernamesManager.get(uuid);
assertTrue(retrieved.isPresent());
assertEquals(retrieved.get(), "n00bkiller");
verify(commands, times(2)).get(eq("UsernameByUuid::" + uuid));
verifyNoMoreInteractions(commands);
verify(usernames, times(1)).get(eq(uuid));
verifyNoMoreInteractions(usernames);
}
}

View File

@@ -1,169 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.storage;
import com.opentable.db.postgres.embedded.LiquibasePreparer;
import com.opentable.db.postgres.junit.EmbeddedPostgresRules;
import com.opentable.db.postgres.junit.PreparedDbRule;
import org.jdbi.v3.core.Jdbi;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase;
import org.whispersystems.textsecuregcm.storage.Usernames;
import java.io.IOException;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Optional;
import java.util.UUID;
import static junit.framework.TestCase.assertTrue;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.junit.Assert.assertFalse;
public class UsernamesTest {
@Rule
public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml"));
private Usernames usernames;
@Before
public void setupAccountsDao() {
FaultTolerantDatabase faultTolerantDatabase = new FaultTolerantDatabase("usernamesTest",
Jdbi.create(db.getTestDatabase()),
new CircuitBreakerConfiguration());
this.usernames = new Usernames(faultTolerantDatabase);
}
@Test
public void testPut() throws SQLException, IOException {
UUID uuid = UUID.randomUUID();
String username = "myusername";
assertTrue(usernames.put(uuid, username));
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM usernames WHERE uuid = ?");
verifyStoredState(statement, uuid, username);
}
@Test
public void testPutChange() throws SQLException, IOException {
UUID uuid = UUID.randomUUID();
String firstUsername = "myfirstusername";
String secondUsername = "mysecondusername";
assertTrue(usernames.put(uuid, firstUsername));
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM usernames WHERE uuid = ?");
verifyStoredState(statement, uuid, firstUsername);
assertTrue(usernames.put(uuid, secondUsername));
verifyStoredState(statement, uuid, secondUsername);
}
@Test
public void testPutConflict() throws SQLException {
UUID firstUuid = UUID.randomUUID();
UUID secondUuid = UUID.randomUUID();
String username = "myfirstusername";
assertTrue(usernames.put(firstUuid, username));
assertFalse(usernames.put(secondUuid, username));
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM usernames WHERE username = ?");
statement.setString(1, username);
ResultSet resultSet = statement.executeQuery();
assertTrue(resultSet.next());
assertThat(resultSet.getString("uuid")).isEqualTo(firstUuid.toString());
assertThat(resultSet.next()).isFalse();
}
@Test
public void testGetByUuid() {
UUID uuid = UUID.randomUUID();
String username = "myusername";
assertTrue(usernames.put(uuid, username));
Optional<String> retrieved = usernames.get(uuid);
assertTrue(retrieved.isPresent());
assertThat(retrieved.get()).isEqualTo(username);
}
@Test
public void testGetByUuidMissing() {
Optional<String> retrieved = usernames.get(UUID.randomUUID());
assertFalse(retrieved.isPresent());
}
@Test
public void testGetByUsername() {
UUID uuid = UUID.randomUUID();
String username = "myusername";
assertTrue(usernames.put(uuid, username));
Optional<UUID> retrieved = usernames.get(username);
assertTrue(retrieved.isPresent());
assertThat(retrieved.get()).isEqualTo(uuid);
}
@Test
public void testGetByUsernameMissing() {
Optional<UUID> retrieved = usernames.get("myusername");
assertFalse(retrieved.isPresent());
}
@Test
public void testDelete() {
UUID uuid = UUID.randomUUID();
String username = "myusername";
assertTrue(usernames.put(uuid, username));
Optional<UUID> retrieved = usernames.get(username);
assertTrue(retrieved.isPresent());
assertThat(retrieved.get()).isEqualTo(uuid);
usernames.delete(uuid);
assertThat(usernames.get(uuid).isPresent()).isFalse();
}
private void verifyStoredState(PreparedStatement statement, UUID uuid, String expectedUsername)
throws SQLException, IOException
{
statement.setObject(1, uuid);
ResultSet resultSet = statement.executeQuery();
if (resultSet.next()) {
String data = resultSet.getString("username");
assertThat(data).isNotEmpty();
assertThat(data).isEqualTo(expectedUsername);
} else {
throw new AssertionError("No data");
}
assertThat(resultSet.next()).isFalse();
}
}

View File

@@ -103,6 +103,10 @@ public class AccountsHelper {
when(updatedAccount.getNumber()).thenAnswer(stubbing);
break;
}
case "getUsername": {
when(updatedAccount.getUsername()).thenAnswer(stubbing);
break;
}
case "getDevices": {
when(updatedAccount.getDevices())
.thenAnswer(stubbing);