Add optimistic locking to account updates

This commit is contained in:
Chris Eager
2021-07-07 11:54:22 -05:00
committed by Jon Chambers
parent 62022c7de1
commit 158d65c6a7
30 changed files with 1397 additions and 399 deletions

View File

@@ -50,6 +50,7 @@ import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
import software.amazon.awssdk.services.dynamodb.model.ScanResponse;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest;
import software.amazon.awssdk.services.dynamodb.model.TransactionConflictException;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
class AccountsDynamoDbTest {
@@ -211,6 +212,10 @@ class AccountsDynamoDbTest {
verifyStoredState("+14151112222", account.getUuid(), account);
account.setProfileName("name");
accountsDynamoDb.update(account);
UUID secondUuid = UUID.randomUUID();
device = generateDevice(1);
@@ -252,13 +257,44 @@ class AccountsDynamoDbTest {
assertThatThrownBy(() -> accountsDynamoDb.update(unknownAccount)).isInstanceOfAny(ConditionalCheckFailedException.class);
account.setDynamoDbMigrationVersion(5);
account.setProfileName("name");
accountsDynamoDb.update(account);
assertThat(account.getVersion()).isEqualTo(2);
verifyStoredState("+14151112222", account.getUuid(), account);
account.setVersion(1);
assertThatThrownBy(() -> accountsDynamoDb.update(account)).isInstanceOfAny(ContestedOptimisticLockException.class);
account.setVersion(2);
account.setProfileName("name2");
accountsDynamoDb.update(account);
verifyStoredState("+14151112222", account.getUuid(), account);
}
@Test
void testUpdateWithMockTransactionConflictException() {
final DynamoDbClient dynamoDbClient = mock(DynamoDbClient.class);
accountsDynamoDb = new AccountsDynamoDb(dynamoDbClient, mock(DynamoDbAsyncClient.class),
new ThreadPoolExecutor(1, 1, 0, TimeUnit.SECONDS, new LinkedBlockingDeque<>()),
dynamoDbExtension.getTableName(), NUMBERS_TABLE_NAME, mock(MigrationDeletedAccounts.class),
mock(MigrationRetryAccounts.class));
when(dynamoDbClient.updateItem(any(UpdateItemRequest.class)))
.thenThrow(TransactionConflictException.class);
Device device = generateDevice (1 );
Account account = generateAccount("+14151112222", UUID.randomUUID(), Collections.singleton(device));
assertThatThrownBy(() -> accountsDynamoDb.update(account)).isInstanceOfAny(ContestedOptimisticLockException.class);
}
@Test
void testRetrieveFrom() {
List<Account> users = new ArrayList<>();
@@ -463,7 +499,7 @@ class AccountsDynamoDbTest {
assertThat(migrated).isFalse();
verifyStoredState("+14151112222", firstUuid, account);
account.setDynamoDbMigrationVersion(account.getDynamoDbMigrationVersion() + 1);
account.setVersion(account.getVersion() + 1);
migrated = accountsDynamoDb.migrate(account).get();
@@ -504,8 +540,8 @@ class AccountsDynamoDbTest {
String data = new String(get.item().get(AccountsDynamoDb.ATTR_ACCOUNT_DATA).b().asByteArray(), StandardCharsets.UTF_8);
assertThat(data).isNotEmpty();
assertThat(AttributeValues.getInt(get.item(), AccountsDynamoDb.ATTR_MIGRATION_VERSION, -1))
.isEqualTo(expecting.getDynamoDbMigrationVersion());
assertThat(AttributeValues.getInt(get.item(), AccountsDynamoDb.ATTR_VERSION, -1))
.isEqualTo(expecting.getVersion());
Account result = AccountsDynamoDb.fromItem(get.item());
verifyStoredState(number, uuid, result, expecting);
@@ -518,6 +554,7 @@ class AccountsDynamoDbTest {
assertThat(result.getNumber()).isEqualTo(number);
assertThat(result.getLastSeen()).isEqualTo(expecting.getLastSeen());
assertThat(result.getUuid()).isEqualTo(uuid);
assertThat(result.getVersion()).isEqualTo(expecting.getVersion());
assertThat(Arrays.equals(result.getUnidentifiedAccessKey().get(), expecting.getUnidentifiedAccessKey().get())).isTrue();
for (Device expectingDevice : expecting.getDevices()) {

View File

@@ -0,0 +1,274 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.opentable.db.postgres.embedded.LiquibasePreparer;
import com.opentable.db.postgres.junit5.EmbeddedPostgresExtension;
import com.opentable.db.postgres.junit5.PreparedDbExtension;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.jdbi.v3.core.Jdbi;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicAccountsDynamoDbMigrationConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.tests.util.JsonHelpers;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest;
import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement;
import software.amazon.awssdk.services.dynamodb.model.KeyType;
import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType;
class AccountsManagerConcurrentModificationIntegrationTest {
@RegisterExtension
static PreparedDbExtension db = EmbeddedPostgresExtension.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml"));
private static final String ACCOUNTS_TABLE_NAME = "accounts_test";
private static final String NUMBERS_TABLE_NAME = "numbers_test";
@RegisterExtension
static DynamoDbExtension dynamoDbExtension = DynamoDbExtension.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.hashKey(AccountsDynamoDb.KEY_ACCOUNT_UUID)
.attributeDefinition(AttributeDefinition.builder()
.attributeName(AccountsDynamoDb.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build())
.build();
private Accounts accounts;
private AccountsDynamoDb accountsDynamoDb;
private AccountsManager accountsManager;
private RedisAdvancedClusterCommands<String, String> commands;
private Executor mutationExecutor = new ThreadPoolExecutor(20, 20, 5, TimeUnit.SECONDS, new LinkedBlockingDeque<>(20));
@BeforeEach
void setup() {
{
CreateTableRequest createNumbersTableRequest = CreateTableRequest.builder()
.tableName(NUMBERS_TABLE_NAME)
.keySchema(KeySchemaElement.builder()
.attributeName(AccountsDynamoDb.ATTR_ACCOUNT_E164)
.keyType(KeyType.HASH)
.build())
.attributeDefinitions(AttributeDefinition.builder()
.attributeName(AccountsDynamoDb.ATTR_ACCOUNT_E164)
.attributeType(ScalarAttributeType.S)
.build())
.provisionedThroughput(DynamoDbExtension.DEFAULT_PROVISIONED_THROUGHPUT)
.build();
dynamoDbExtension.getDynamoDbClient().createTable(createNumbersTableRequest);
}
accountsDynamoDb = new AccountsDynamoDb(
dynamoDbExtension.getDynamoDbClient(),
dynamoDbExtension.getDynamoDbAsyncClient(),
new ThreadPoolExecutor(1, 1, 0, TimeUnit.SECONDS, new LinkedBlockingDeque<>()),
dynamoDbExtension.getTableName(),
NUMBERS_TABLE_NAME,
mock(MigrationDeletedAccounts.class),
mock(MigrationRetryAccounts.class));
{
final CircuitBreakerConfiguration circuitBreakerConfiguration = new CircuitBreakerConfiguration();
circuitBreakerConfiguration.setIgnoredExceptions(List.of("org.whispersystems.textsecuregcm.storage.ContestedOptimisticLockException"));
FaultTolerantDatabase faultTolerantDatabase = new FaultTolerantDatabase("accountsTest",
Jdbi.create(db.getTestDatabase()),
circuitBreakerConfiguration);
accounts = new Accounts(faultTolerantDatabase);
}
{
final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
DynamicConfiguration dynamicConfiguration = new DynamicConfiguration();
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
final DynamicAccountsDynamoDbMigrationConfiguration config = dynamicConfiguration
.getAccountsDynamoDbMigrationConfiguration();
config.setDeleteEnabled(true);
config.setReadEnabled(true);
config.setWriteEnabled(true);
when(experimentEnrollmentManager.isEnrolled(any(UUID.class), anyString())).thenReturn(true);
commands = mock(RedisAdvancedClusterCommands.class);
accountsManager = new AccountsManager(
accounts,
accountsDynamoDb,
RedisClusterHelper.buildMockRedisCluster(commands),
mock(DeletedAccounts.class),
mock(DirectoryQueue.class),
mock(KeysDynamoDb.class),
mock(MessagesManager.class),
mock(UsernamesManager.class),
mock(ProfilesManager.class),
mock(SecureStorageClient.class),
mock(SecureBackupClient.class),
experimentEnrollmentManager,
dynamicConfigurationManager);
}
}
@Test
void testConcurrentUpdate() throws IOException {
final UUID uuid = UUID.randomUUID();
accountsManager.create(generateAccount("+14155551212", uuid));
final String profileName = "name";
final String avatar = "avatar";
final boolean discoverableByPhoneNumber = false;
final String currentProfileVersion = "cpv";
final String identityKey = "ikey";
final byte[] unidentifiedAccessKey = new byte[]{1};
final String pin = "1234";
final String registrationLock = "reglock";
final AuthenticationCredentials credentials = new AuthenticationCredentials(registrationLock);
final boolean unrestrictedUnidentifiedAccess = true;
final long lastSeen = Instant.now().getEpochSecond();
CompletableFuture.allOf(
modifyAccount(uuid, account -> account.setProfileName(profileName)),
modifyAccount(uuid, account -> account.setAvatar(avatar)),
modifyAccount(uuid, account -> account.setDiscoverableByPhoneNumber(discoverableByPhoneNumber)),
modifyAccount(uuid, account -> account.setCurrentProfileVersion(currentProfileVersion)),
modifyAccount(uuid, account -> account.setIdentityKey(identityKey)),
modifyAccount(uuid, account -> account.setUnidentifiedAccessKey(unidentifiedAccessKey)),
modifyAccount(uuid, account -> account.setPin(pin)),
modifyAccount(uuid, account -> account.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt())),
modifyAccount(uuid, account -> account.setUnrestrictedUnidentifiedAccess(unrestrictedUnidentifiedAccess)),
modifyDevice(uuid, Device.MASTER_ID, device-> device.setLastSeen(lastSeen)),
modifyDevice(uuid, Device.MASTER_ID, device-> device.setName("deviceName"))
).join();
final Account managerAccount = accountsManager.get(uuid).get();
final Account dbAccount = accounts.get(uuid).get();
final Account dynamoAccount = accountsDynamoDb.get(uuid).get();
final Account redisAccount = getLastAccountFromRedisMock(commands);
Stream.of(
new Pair<>("manager", managerAccount),
new Pair<>("db", dbAccount),
new Pair<>("dynamo", dynamoAccount),
new Pair<>("redis", redisAccount)
).forEach(pair ->
verifyAccount(pair.first(), pair.second(), profileName, avatar, discoverableByPhoneNumber,
currentProfileVersion, identityKey, unidentifiedAccessKey, pin, registrationLock,
unrestrictedUnidentifiedAccess, lastSeen)
);
}
private Account getLastAccountFromRedisMock(RedisAdvancedClusterCommands<String, String> commands) throws IOException {
ArgumentCaptor<String> redisSetArgumentCapture = ArgumentCaptor.forClass(String.class);
verify(commands, atLeast(20)).set(anyString(), redisSetArgumentCapture.capture());
return JsonHelpers.fromJson(redisSetArgumentCapture.getValue(), Account.class);
}
private void verifyAccount(final String name, final Account account, final String profileName, final String avatar, final boolean discoverableByPhoneNumber, final String currentProfileVersion, final String identityKey, final byte[] unidentifiedAccessKey, final String pin, final String clientRegistrationLock, final boolean unrestrictedUnidentifiedAcces, final long lastSeen) {
assertAll(name,
() -> assertEquals(profileName, account.getProfileName()),
() -> assertEquals(avatar, account.getAvatar()),
() -> assertEquals(discoverableByPhoneNumber, account.isDiscoverableByPhoneNumber()),
() -> assertEquals(currentProfileVersion, account.getCurrentProfileVersion().get()),
() -> assertEquals(identityKey, account.getIdentityKey()),
() -> assertArrayEquals(unidentifiedAccessKey, account.getUnidentifiedAccessKey().get()),
() -> assertTrue(account.getRegistrationLock().verify(clientRegistrationLock, pin)),
() -> assertEquals(unrestrictedUnidentifiedAcces, account.isUnrestrictedUnidentifiedAccess())
);
}
private CompletableFuture<?> modifyAccount(final UUID uuid, final Consumer<Account> accountMutation) {
return CompletableFuture.runAsync(() -> {
final Account account = accountsManager.get(uuid).get();
accountsManager.update(account, accountMutation);
}, mutationExecutor);
}
private CompletableFuture<?> modifyDevice(final UUID uuid, final long deviceId, final Consumer<Device> deviceMutation) {
return CompletableFuture.runAsync(() -> {
final Account account = accountsManager.get(uuid).get();
accountsManager.updateDevice(account, deviceId, deviceMutation);
}, mutationExecutor);
}
private Account generateAccount(String number, UUID uuid) {
Device device = generateDevice(1);
return generateAccount(number, uuid, Collections.singleton(device));
}
private Account generateAccount(String number, UUID uuid, Set<Device> devices) {
byte[] unidentifiedAccessKey = new byte[16];
Random random = new Random(System.currentTimeMillis());
Arrays.fill(unidentifiedAccessKey, (byte)random.nextInt(255));
return new Account(number, uuid, devices, unidentifiedAccessKey);
}
private Device generateDevice(long id) {
Random random = new Random(System.currentTimeMillis());
SignedPreKey signedPreKey = new SignedPreKey(random.nextInt(), "testPublicKey-" + random.nextInt(), "testSignature-" + random.nextInt());
return new Device(id, "testName-" + random.nextInt(), "testAuthToken-" + random.nextInt(), "testSalt-" + random.nextInt(),
"testGcmId-" + random.nextInt(), "testApnId-" + random.nextInt(), "testVoipApnId-" + random.nextInt(), random.nextBoolean(), random.nextInt(), signedPreKey, random.nextInt(), random.nextInt(), "testUserAgent-" + random.nextInt() , 0, new Device.DeviceCapabilities(random.nextBoolean(), random.nextBoolean(), random.nextBoolean(), random.nextBoolean(), random.nextBoolean(), random.nextBoolean(),
random.nextBoolean(), random.nextBoolean()));
}
}

View File

@@ -5,104 +5,120 @@
package org.whispersystems.textsecuregcm.tests.auth;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import java.time.Clock;
import java.time.Instant;
import java.util.Random;
import java.util.Set;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class BaseAccountAuthenticatorTest {
import java.time.Clock;
import java.time.Instant;
import java.util.Random;
import java.util.Set;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
private final Random random = new Random(867_5309L);
private final long today = 1590451200000L;
private final long yesterday = today - 86_400_000L;
private final long oldTime = yesterday - 86_400_000L;
private final long currentTime = today + 68_000_000L;
class BaseAccountAuthenticatorTest {
private AccountsManager accountsManager;
private BaseAccountAuthenticator baseAccountAuthenticator;
private Clock clock;
private Account acct1;
private Account acct2;
private Account oldAccount;
private final Random random = new Random(867_5309L);
private final long today = 1590451200000L;
private final long yesterday = today - 86_400_000L;
private final long oldTime = yesterday - 86_400_000L;
private final long currentTime = today + 68_000_000L;
@Before
public void setup() {
accountsManager = mock(AccountsManager.class);
clock = mock(Clock.class);
baseAccountAuthenticator = new BaseAccountAuthenticator(accountsManager, clock);
private AccountsManager accountsManager;
private BaseAccountAuthenticator baseAccountAuthenticator;
private Clock clock;
private Account acct1;
private Account acct2;
private Account oldAccount;
acct1 = new Account("+14088675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null);
acct2 = new Account("+14098675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null);
oldAccount = new Account("+14108675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, oldTime, 0, null, 0, null)), null);
}
@BeforeEach
void setup() {
accountsManager = mock(AccountsManager.class);
clock = mock(Clock.class);
baseAccountAuthenticator = new BaseAccountAuthenticator(accountsManager, clock);
@Test
public void testUpdateLastSeenMiddleOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(currentTime));
acct1 = new Account("+14088675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null);
acct2 = new Account("+14098675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null);
oldAccount = new Account("+14108675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, oldTime, 0, null, 0, null)), null);
baseAccountAuthenticator.updateLastSeen(acct1, acct1.getDevices().stream().findFirst().get());
baseAccountAuthenticator.updateLastSeen(acct2, acct2.getDevices().stream().findFirst().get());
AccountsHelper.setupMockUpdate(accountsManager);
}
verify(accountsManager, never()).update(acct1);
verify(accountsManager).update(acct2);
@Test
void testUpdateLastSeenMiddleOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(currentTime));
assertThat(acct1.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(yesterday);
assertThat(acct2.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(today);
}
final Device device1 = acct1.getDevices().stream().findFirst().get();
final Device device2 = acct2.getDevices().stream().findFirst().get();
@Test
public void testUpdateLastSeenStartOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today));
baseAccountAuthenticator.updateLastSeen(acct1, device1);
baseAccountAuthenticator.updateLastSeen(acct2, device2);
baseAccountAuthenticator.updateLastSeen(acct1, acct1.getDevices().stream().findFirst().get());
baseAccountAuthenticator.updateLastSeen(acct2, acct2.getDevices().stream().findFirst().get());
verify(accountsManager, never()).updateDevice(eq(acct1), anyLong(), any());
verify(accountsManager).updateDevice(eq(acct2), anyLong(), any());
verify(accountsManager, never()).update(acct1);
verify(accountsManager, never()).update(acct2);
assertThat(device1.getLastSeen()).isEqualTo(yesterday);
assertThat(device2.getLastSeen()).isEqualTo(today);
}
assertThat(acct1.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(yesterday);
assertThat(acct2.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(yesterday);
}
@Test
void testUpdateLastSeenStartOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today));
@Test
public void testUpdateLastSeenEndOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today + 86_400_000L - 1));
final Device device1 = acct1.getDevices().stream().findFirst().get();
final Device device2 = acct2.getDevices().stream().findFirst().get();
baseAccountAuthenticator.updateLastSeen(acct1, acct1.getDevices().stream().findFirst().get());
baseAccountAuthenticator.updateLastSeen(acct2, acct2.getDevices().stream().findFirst().get());
baseAccountAuthenticator.updateLastSeen(acct1, device1);
baseAccountAuthenticator.updateLastSeen(acct2, device2);
verify(accountsManager).update(acct1);
verify(accountsManager).update(acct2);
verify(accountsManager, never()).updateDevice(eq(acct1), anyLong(), any());
verify(accountsManager, never()).updateDevice(eq(acct2), anyLong(), any());
assertThat(acct1.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(today);
assertThat(acct2.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(today);
}
assertThat(device1.getLastSeen()).isEqualTo(yesterday);
assertThat(device2.getLastSeen()).isEqualTo(yesterday);
}
@Test
public void testNeverWriteYesterday() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today));
@Test
void testUpdateLastSeenEndOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today + 86_400_000L - 1));
baseAccountAuthenticator.updateLastSeen(oldAccount, oldAccount.getDevices().stream().findFirst().get());
final Device device1 = acct1.getDevices().stream().findFirst().get();
final Device device2 = acct2.getDevices().stream().findFirst().get();
verify(accountsManager).update(oldAccount);
baseAccountAuthenticator.updateLastSeen(acct1, device1);
baseAccountAuthenticator.updateLastSeen(acct2, device2);
assertThat(oldAccount.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(today);
}
verify(accountsManager).updateDevice(eq(acct1), anyLong(), any());
verify(accountsManager).updateDevice(eq(acct2), anyLong(), any());
assertThat(device1.getLastSeen()).isEqualTo(today);
assertThat(device2.getLastSeen()).isEqualTo(today);
}
@Test
void testNeverWriteYesterday() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today));
final Device device = oldAccount.getDevices().stream().findFirst().get();
baseAccountAuthenticator.updateLastSeen(oldAccount, device);
verify(accountsManager).updateDevice(eq(oldAccount), anyLong(), any());
assertThat(device.getLastSeen()).isEqualTo(today);
}
}

View File

@@ -11,6 +11,7 @@ import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.*;
import static org.whispersystems.textsecuregcm.tests.util.AccountsHelper.eqUuid;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
@@ -78,6 +79,7 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.storage.UsernamesManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.Hex;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@@ -171,6 +173,8 @@ class AccountControllerTest {
new SecureRandom().nextBytes(registration_lock_key);
AuthenticationCredentials registrationLockCredentials = new AuthenticationCredentials(Hex.toStringCondensed(registration_lock_key));
AccountsHelper.setupMockUpdate(accountsManager);
when(rateLimiters.getSmsDestinationLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVoiceDestinationLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVoiceDestinationDailyLimiter()).thenReturn(rateLimiter);
@@ -1352,7 +1356,7 @@ class AccountControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.DISABLED_DEVICE, times(1)).setGcmId(eq("c00lz0rz"));
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class));
}
@@ -1368,7 +1372,7 @@ class AccountControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.DISABLED_DEVICE, times(1)).setGcmId(eq("z000"));
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class));
}
@@ -1385,7 +1389,7 @@ class AccountControllerTest {
verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("first"));
verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(eq("second"));
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class));
}
@@ -1402,7 +1406,7 @@ class AccountControllerTest {
verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("first"));
verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(null);
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class));
}
@@ -1419,7 +1423,7 @@ class AccountControllerTest {
verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("third"));
verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(eq("fourth"));
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class));
}
@@ -1544,7 +1548,7 @@ class AccountControllerTest {
.put(Entity.json(new AccountAttributes(false, 2222, null, null, null, true, null)));
assertThat(response.getStatus()).isEqualTo(204);
verify(directoryQueue, times(1)).refreshRegisteredUser(AuthHelper.UNDISCOVERABLE_ACCOUNT);
verify(directoryQueue, times(1)).refreshRegisteredUser(eqUuid(AuthHelper.UNDISCOVERABLE_ACCOUNT));
}
@Test
@@ -1557,7 +1561,7 @@ class AccountControllerTest {
.put(Entity.json(new AccountAttributes(false, 2222, null, null, null, false, null)));
assertThat(response.getStatus()).isEqualTo(204);
verify(directoryQueue, times(1)).refreshRegisteredUser(AuthHelper.VALID_ACCOUNT);
verify(directoryQueue, times(1)).refreshRegisteredUser(eqUuid(AuthHelper.VALID_ACCOUNT));
}
@Test

View File

@@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.tests.controllers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
@@ -47,6 +48,7 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.VerificationCode;
@@ -124,6 +126,8 @@ public class DeviceControllerTest {
when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.empty());
when(accountsManager.get(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(account));
when(accountsManager.get(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(maxedAccount));
AccountsHelper.setupMockUpdate(accountsManager);
}
@Test
@@ -360,7 +364,7 @@ public class DeviceControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(messagesManager, times(2)).clear(AuthHelper.VALID_UUID, deviceId);
verify(accountsManager, times(1)).update(AuthHelper.VALID_ACCOUNT);
verify(accountsManager, times(1)).update(eq(AuthHelper.VALID_ACCOUNT), any());
verify(AuthHelper.VALID_ACCOUNT).removeDevice(deviceId);
}

View File

@@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.tests.controllers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.argThat;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
@@ -15,6 +16,7 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.tests.util.AccountsHelper.eqUuid;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
@@ -61,6 +63,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ExtendWith(DropwizardExtensionsSupport.class)
@@ -114,13 +117,15 @@ class KeysControllerTest {
final Device sampleDevice3 = mock(Device.class);
final Device sampleDevice4 = mock(Device.class);
Set<Device> allDevices = new HashSet<Device>() {{
Set<Device> allDevices = new HashSet<>() {{
add(sampleDevice);
add(sampleDevice2);
add(sampleDevice3);
add(sampleDevice4);
}};
AccountsHelper.setupMockUpdate(accounts);
when(sampleDevice.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID);
when(sampleDevice2.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice3.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
@@ -142,7 +147,7 @@ class KeysControllerTest {
when(existsAccount.getDevice(2L)).thenReturn(Optional.of(sampleDevice2));
when(existsAccount.getDevice(3L)).thenReturn(Optional.of(sampleDevice3));
when(existsAccount.getDevice(4L)).thenReturn(Optional.of(sampleDevice4));
when(existsAccount.getDevice(22L)).thenReturn(Optional.<Device>empty());
when(existsAccount.getDevice(22L)).thenReturn(Optional.empty());
when(existsAccount.getDevices()).thenReturn(allDevices);
when(existsAccount.isEnabled()).thenReturn(true);
when(existsAccount.getIdentityKey()).thenReturn("existsidentitykey");
@@ -256,7 +261,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT));
verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyLong(), any());
}
@Test
@@ -271,7 +276,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT));
verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyLong(), any());
}
@@ -578,7 +583,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class);
verify(keysDynamoDb).store(eq(AuthHelper.VALID_ACCOUNT), eq(1L), listCaptor.capture());
verify(keysDynamoDb).store(eqUuid(AuthHelper.VALID_ACCOUNT), eq(1L), listCaptor.capture());
List<PreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1);
@@ -587,7 +592,7 @@ class KeysControllerTest {
verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq("barbar"));
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));
verify(accounts).update(AuthHelper.VALID_ACCOUNT);
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any());
}
@Test
@@ -612,7 +617,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class);
verify(keysDynamoDb).store(eq(AuthHelper.DISABLED_ACCOUNT), eq(1L), listCaptor.capture());
verify(keysDynamoDb).store(eqUuid(AuthHelper.DISABLED_ACCOUNT), eq(1L), listCaptor.capture());
List<PreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1);
@@ -621,7 +626,7 @@ class KeysControllerTest {
verify(AuthHelper.DISABLED_ACCOUNT).setIdentityKey(eq("barbar"));
verify(AuthHelper.DISABLED_DEVICE).setSignedPreKey(eq(signedPreKey));
verify(accounts).update(AuthHelper.DISABLED_ACCOUNT);
verify(accounts).update(eq(AuthHelper.DISABLED_ACCOUNT), any());
}
@Test

View File

@@ -20,7 +20,8 @@ import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit.ResourceTestRule;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.util.Collections;
import java.util.Optional;
import java.util.Set;
@@ -29,9 +30,10 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.apache.commons.lang3.RandomStringUtils;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.signal.zkgroup.InvalidInputException;
@@ -57,13 +59,15 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.UsernamesManager;
import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
public class ProfileControllerTest {
@ExtendWith(DropwizardExtensionsSupport.class)
class ProfileControllerTest {
private static AccountsManager accountsManager = mock(AccountsManager.class );
private static ProfilesManager profilesManager = mock(ProfilesManager.class);
@@ -82,30 +86,30 @@ public class ProfileControllerTest {
private Account profileAccount;
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new ProfileController(rateLimiters,
accountsManager,
profilesManager,
usernamesManager,
dynamicConfigurationManager,
s3client,
postPolicyGenerator,
policySigner,
"profilesBucket",
zkProfileOperations,
true))
.build();
@ClassRule
public static final ResourceTestRule resources = ResourceTestRule.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new ProfileController(rateLimiters,
accountsManager,
profilesManager,
usernamesManager,
dynamicConfigurationManager,
s3client,
postPolicyGenerator,
policySigner,
"profilesBucket",
zkProfileOperations,
true))
.build();
@Before
public void setup() throws Exception {
@BeforeEach
void setup() throws Exception {
reset(s3client);
AccountsHelper.setupMockUpdate(accountsManager);
dynamicPaymentsConfiguration = mock(DynamicPaymentsConfiguration.class);
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
@@ -161,8 +165,13 @@ public class ProfileControllerTest {
clearInvocations(profilesManager);
}
@AfterEach
void teardown() {
reset(accountsManager);
}
@Test
public void testProfileGetByUuid() throws RateLimitExceededException {
void testProfileGetByUuid() throws RateLimitExceededException {
Profile profile= resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_UUID_TWO)
.request()
@@ -180,7 +189,7 @@ public class ProfileControllerTest {
}
@Test
public void testProfileGetByNumber() throws RateLimitExceededException {
void testProfileGetByNumber() throws RateLimitExceededException {
Profile profile= resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_NUMBER_TWO)
.request()
@@ -201,7 +210,7 @@ public class ProfileControllerTest {
}
@Test
public void testProfileGetByUsername() throws RateLimitExceededException {
void testProfileGetByUsername() throws RateLimitExceededException {
Profile profile= resources.getJerseyTest()
.target("/v1/profile/username/n00bkiller")
.request()
@@ -220,7 +229,7 @@ public class ProfileControllerTest {
}
@Test
public void testProfileGetUnauthorized() {
void testProfileGetUnauthorized() {
Response response = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_NUMBER_TWO)
.request()
@@ -230,7 +239,7 @@ public class ProfileControllerTest {
}
@Test
public void testProfileGetByUsernameUnauthorized() {
void testProfileGetByUsernameUnauthorized() {
Response response = resources.getJerseyTest()
.target("/v1/profile/username/n00bkiller")
.request()
@@ -241,7 +250,7 @@ public class ProfileControllerTest {
@Test
public void testProfileGetByUsernameNotFound() throws RateLimitExceededException {
void testProfileGetByUsernameNotFound() throws RateLimitExceededException {
Response response = resources.getJerseyTest()
.target("/v1/profile/username/n00bkillerzzzzz")
.request()
@@ -256,7 +265,7 @@ public class ProfileControllerTest {
@Test
public void testProfileGetDisabled() {
void testProfileGetDisabled() {
Response response = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_NUMBER_TWO)
.request()
@@ -267,7 +276,7 @@ public class ProfileControllerTest {
}
@Test
public void testProfileCapabilities() {
void testProfileCapabilities() {
Profile profile= resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_NUMBER)
.request()
@@ -293,7 +302,7 @@ public class ProfileControllerTest {
}
@Test
public void testSetProfileNameDeprecated() {
void testSetProfileNameDeprecated() {
Response response = resources.getJerseyTest()
.target("/v1/profile/name/123456789012345678901234567890123456789012345678901234567890123456789012")
.request()
@@ -302,11 +311,11 @@ public class ProfileControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(accountsManager, times(1)).update(any(Account.class));
verify(accountsManager, times(1)).update(any(Account.class), any());
}
@Test
public void testSetProfileNameExtendedDeprecated() {
void testSetProfileNameExtendedDeprecated() {
Response response = resources.getJerseyTest()
.target("/v1/profile/name/123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678")
.request()
@@ -315,11 +324,11 @@ public class ProfileControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(accountsManager, times(1)).update(any(Account.class));
verify(accountsManager, times(1)).update(any(Account.class), any());
}
@Test
public void testSetProfileNameWrongSizeDeprecated() {
void testSetProfileNameWrongSizeDeprecated() {
Response response = resources.getJerseyTest()
.target("/v1/profile/name/1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890")
.request()
@@ -333,7 +342,7 @@ public class ProfileControllerTest {
/////
@Test
public void testSetProfileWantAvatarUpload() throws InvalidInputException {
void testSetProfileWantAvatarUpload() throws InvalidInputException {
ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID);
ProfileAvatarUploadAttributes uploadAttributes = resources.getJerseyTest()
@@ -358,7 +367,7 @@ public class ProfileControllerTest {
assertThat(profileArgumentCaptor.getValue().getAbout()).isNull(); }
@Test
public void testSetProfileWantAvatarUploadWithBadProfileSize() throws InvalidInputException {
void testSetProfileWantAvatarUploadWithBadProfileSize() throws InvalidInputException {
ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID);
Response response = resources.getJerseyTest()
@@ -372,7 +381,7 @@ public class ProfileControllerTest {
}
@Test
public void testSetProfileWithoutAvatarUpload() throws InvalidInputException {
void testSetProfileWithoutAvatarUpload() throws InvalidInputException {
ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID);
clearInvocations(AuthHelper.VALID_ACCOUNT_TWO);
@@ -406,7 +415,7 @@ public class ProfileControllerTest {
}
@Test
public void testSetProfileWithAvatarUploadAndPreviousAvatar() throws InvalidInputException {
void testSetProfileWithAvatarUploadAndPreviousAvatar() throws InvalidInputException {
ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID_TWO);
ProfileAvatarUploadAttributes uploadAttributes= resources.getJerseyTest()
@@ -430,7 +439,7 @@ public class ProfileControllerTest {
assertThat(profileArgumentCaptor.getValue().getAbout()).isNull(); }
@Test
public void testSetProfileExtendedName() throws InvalidInputException {
void testSetProfileExtendedName() throws InvalidInputException {
ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID_TWO);
final String name = RandomStringUtils.randomAlphabetic(380);
@@ -456,7 +465,7 @@ public class ProfileControllerTest {
}
@Test
public void testSetProfileEmojiAndBioText() throws InvalidInputException {
void testSetProfileEmojiAndBioText() throws InvalidInputException {
ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID);
clearInvocations(AuthHelper.VALID_ACCOUNT_TWO);
@@ -495,7 +504,7 @@ public class ProfileControllerTest {
}
@Test
public void testSetProfilePaymentAddress() throws InvalidInputException {
void testSetProfilePaymentAddress() throws InvalidInputException {
when(dynamicPaymentsConfiguration.getAllowedCountryCodes())
.thenReturn(Set.of(Util.getCountryCode(AuthHelper.VALID_NUMBER_TWO)));
@@ -536,7 +545,7 @@ public class ProfileControllerTest {
}
@Test
public void testSetProfilePaymentAddressCountryNotAllowed() throws InvalidInputException {
void testSetProfilePaymentAddressCountryNotAllowed() throws InvalidInputException {
ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID);
clearInvocations(AuthHelper.VALID_ACCOUNT_TWO);
@@ -557,7 +566,7 @@ public class ProfileControllerTest {
}
@Test
public void testGetProfileByVersion() throws RateLimitExceededException {
void testGetProfileByVersion() throws RateLimitExceededException {
Profile profile = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_UUID_TWO + "/validversion")
.request()
@@ -582,7 +591,7 @@ public class ProfileControllerTest {
}
@Test
public void testSetProfileUpdatesAccountCurrentVersion() throws InvalidInputException {
void testSetProfileUpdatesAccountCurrentVersion() throws InvalidInputException {
when(dynamicPaymentsConfiguration.getAllowedCountryCodes())
.thenReturn(Set.of(Util.getCountryCode(AuthHelper.VALID_NUMBER_TWO)));
@@ -606,7 +615,7 @@ public class ProfileControllerTest {
}
@Test
public void testGetProfileReturnsNoPaymentAddressIfCurrentVersionMismatch() {
void testGetProfileReturnsNoPaymentAddressIfCurrentVersionMismatch() {
when(profilesManager.get(AuthHelper.VALID_UUID_TWO, "validversion")).thenReturn(
Optional.of(new VersionedProfile(null, null, null, null, null, "paymentaddress", null)));
Profile profile = resources.getJerseyTest()

View File

@@ -23,6 +23,7 @@ import org.whispersystems.textsecuregcm.push.GcmMessage;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.SynchronousExecutorService;
import org.whispersystems.textsecuregcm.util.Util;
@@ -40,6 +41,8 @@ public class GCMSenderTest {
when(successResult.hasCanonicalRegistrationId()).thenReturn(false);
when(successResult.isSuccess()).thenReturn(true);
AccountsHelper.setupMockUpdate(accountsManager);
GcmMessage message = new GcmMessage("foo", "+12223334444", 1, GcmMessage.Type.NOTIFICATION, Optional.empty());
GCMSender gcmSender = new GCMSender(executorService, accountsManager, sender);
@@ -65,6 +68,8 @@ public class GCMSenderTest {
Account destinationAccount = mock(Account.class);
Device destinationDevice = mock(Device.class );
AccountsHelper.setupMockUpdate(accountsManager);
when(destinationAccount.getDevice(1)).thenReturn(Optional.of(destinationDevice));
when(accountsManager.get(destinationNumber)).thenReturn(Optional.of(destinationAccount));
when(destinationDevice.getGcmId()).thenReturn(gcmId);
@@ -85,7 +90,7 @@ public class GCMSenderTest {
verify(sender, times(1)).send(any(Message.class));
verify(accountsManager, times(1)).get(eq(destinationNumber));
verify(accountsManager, times(1)).update(eq(destinationAccount));
verify(accountsManager, times(1)).updateDevice(eq(destinationAccount), eq(1L), any());
verify(destinationDevice, times(1)).setUninstalledFeedbackTimestamp(eq(Util.todayInMillis()));
}
@@ -107,6 +112,8 @@ public class GCMSenderTest {
when(accountsManager.get(destinationNumber)).thenReturn(Optional.of(destinationAccount));
when(destinationDevice.getGcmId()).thenReturn(gcmId);
AccountsHelper.setupMockUpdate(accountsManager);
when(canonicalResult.isInvalidRegistrationId()).thenReturn(false);
when(canonicalResult.isUnregistered()).thenReturn(false);
when(canonicalResult.hasCanonicalRegistrationId()).thenReturn(true);
@@ -124,7 +131,7 @@ public class GCMSenderTest {
verify(sender, times(1)).send(any(Message.class));
verify(accountsManager, times(1)).get(eq(destinationNumber));
verify(accountsManager, times(1)).update(eq(destinationAccount));
verify(accountsManager, times(1)).updateDevice(eq(destinationAccount), eq(1L), any());
verify(destinationDevice, times(1)).setGcmId(eq(canonicalId));
}

View File

@@ -6,8 +6,10 @@
package org.whispersystems.textsecuregcm.tests.storage;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@@ -17,26 +19,26 @@ import java.util.HashSet;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
public class AccountTest {
class AccountTest {
private final Device oldMasterDevice = mock(Device.class);
private final Device recentMasterDevice = mock(Device.class);
private final Device agingSecondaryDevice = mock(Device.class);
private final Device oldMasterDevice = mock(Device.class);
private final Device recentMasterDevice = mock(Device.class);
private final Device agingSecondaryDevice = mock(Device.class);
private final Device recentSecondaryDevice = mock(Device.class);
private final Device oldSecondaryDevice = mock(Device.class);
private final Device oldSecondaryDevice = mock(Device.class);
private final Device gv2CapableDevice = mock(Device.class);
private final Device gv2IncapableDevice = mock(Device.class);
private final Device gv2CapableDevice = mock(Device.class);
private final Device gv2IncapableDevice = mock(Device.class);
private final Device gv2IncapableExpiredDevice = mock(Device.class);
private final Device gv1MigrationCapableDevice = mock(Device.class);
private final Device gv1MigrationIncapableDevice = mock(Device.class);
private final Device gv1MigrationIncapableDevice = mock(Device.class);
private final Device gv1MigrationIncapableExpiredDevice = mock(Device.class);
private final Device senderKeyCapableDevice = mock(Device.class);
@@ -47,8 +49,8 @@ public class AccountTest {
private final Device announcementGroupIncapableDevice = mock(Device.class);
private final Device announcementGroupIncapableExpiredDevice = mock(Device.class);
@Before
public void setup() {
@BeforeEach
void setup() {
when(oldMasterDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(366));
when(oldMasterDevice.isEnabled()).thenReturn(true);
when(oldMasterDevice.getId()).thenReturn(Device.MASTER_ID);
@@ -119,9 +121,9 @@ public class AccountTest {
}
@Test
public void testIsEnabled() {
final Device enabledMasterDevice = mock(Device.class);
final Device enabledLinkedDevice = mock(Device.class);
void testIsEnabled() {
final Device enabledMasterDevice = mock(Device.class);
final Device enabledLinkedDevice = mock(Device.class);
final Device disabledMasterDevice = mock(Device.class);
final Device disabledLinkedDevice = mock(Device.class);
@@ -144,7 +146,7 @@ public class AccountTest {
}
@Test
public void testCapabilities() {
void testCapabilities() {
Account uuidCapable = new Account("+14152222222", UUID.randomUUID(), new HashSet<Device>() {{
add(gv2CapableDevice);
}}, "1234".getBytes());
@@ -165,13 +167,13 @@ public class AccountTest {
}
@Test
public void testIsTransferSupported() {
final Device transferCapableMasterDevice = mock(Device.class);
final Device nonTransferCapableMasterDevice = mock(Device.class);
final Device transferCapableLinkedDevice = mock(Device.class);
void testIsTransferSupported() {
final Device transferCapableMasterDevice = mock(Device.class);
final Device nonTransferCapableMasterDevice = mock(Device.class);
final Device transferCapableLinkedDevice = mock(Device.class);
final DeviceCapabilities transferCapabilities = mock(DeviceCapabilities.class);
final DeviceCapabilities nonTransferCapabilities = mock(DeviceCapabilities.class);
final DeviceCapabilities transferCapabilities = mock(DeviceCapabilities.class);
final DeviceCapabilities nonTransferCapabilities = mock(DeviceCapabilities.class);
when(transferCapableMasterDevice.getId()).thenReturn(1L);
when(transferCapableMasterDevice.isMaster()).thenReturn(true);
@@ -213,10 +215,12 @@ public class AccountTest {
}
@Test
public void testDiscoverableByPhoneNumber() {
final Account account = new Account("+14152222222", UUID.randomUUID(), Collections.singleton(recentMasterDevice), "1234".getBytes());
void testDiscoverableByPhoneNumber() {
final Account account = new Account("+14152222222", UUID.randomUUID(), Collections.singleton(recentMasterDevice),
"1234".getBytes());
assertTrue("Freshly-loaded legacy accounts should be discoverable by phone number.", account.isDiscoverableByPhoneNumber());
assertTrue(account.isDiscoverableByPhoneNumber(),
"Freshly-loaded legacy accounts should be discoverable by phone number.");
account.setDiscoverableByPhoneNumber(false);
assertFalse(account.isDiscoverableByPhoneNumber());
@@ -226,21 +230,29 @@ public class AccountTest {
}
@Test
public void isGroupsV2Supported() {
assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported());
assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice, gv2IncapableExpiredDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported());
assertFalse(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice, gv2IncapableDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported());
void isGroupsV2Supported() {
assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported());
assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice, gv2IncapableExpiredDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported());
assertFalse(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice, gv2IncapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported());
}
@Test
public void isGv1MigrationSupported() {
assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv1MigrationCapableDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported());
assertFalse(new Account("+18005551234", UUID.randomUUID(), Set.of(gv1MigrationCapableDevice, gv1MigrationIncapableDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported());
assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv1MigrationCapableDevice, gv1MigrationIncapableExpiredDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported());
void isGv1MigrationSupported() {
assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv1MigrationCapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported());
assertFalse(
new Account("+18005551234", UUID.randomUUID(), Set.of(gv1MigrationCapableDevice, gv1MigrationIncapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported());
assertTrue(new Account("+18005551234", UUID.randomUUID(),
Set.of(gv1MigrationCapableDevice, gv1MigrationIncapableExpiredDevice), "1234".getBytes(StandardCharsets.UTF_8))
.isGv1MigrationSupported());
}
@Test
public void isSenderKeySupported() {
void isSenderKeySupported() {
assertThat(new Account("+18005551234", UUID.randomUUID(), Set.of(senderKeyCapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isSenderKeySupported()).isTrue();
assertThat(new Account("+18005551234", UUID.randomUUID(), Set.of(senderKeyCapableDevice, senderKeyIncapableDevice),
@@ -251,7 +263,7 @@ public class AccountTest {
}
@Test
public void isAnnouncementGroupSupported() {
void isAnnouncementGroupSupported() {
assertThat(new Account("+18005551234", UUID.randomUUID(),
Set.of(announcementGroupCapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isAnnouncementGroupSupported()).isTrue();
@@ -262,4 +274,16 @@ public class AccountTest {
Set.of(announcementGroupCapableDevice, announcementGroupIncapableExpiredDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isAnnouncementGroupSupported()).isTrue();
}
@Test
void stale() {
final Account account = new Account("+14151234567", UUID.randomUUID(), Collections.emptySet(), new byte[0]);
assertDoesNotThrow(account::getNumber);
account.markStale();
assertThrows(AssertionError.class, account::getNumber);
assertDoesNotThrow(account::getUuid);
}
}

View File

@@ -6,11 +6,14 @@
package org.whispersystems.textsecuregcm.tests.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@@ -22,13 +25,17 @@ import static org.mockito.Mockito.when;
import io.lettuce.core.RedisException;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.IOException;
import java.util.HashSet;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Consumer;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicAccountsDynamoDbMigrationConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
@@ -41,6 +48,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsDynamoDb;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ContestedOptimisticLockException;
import org.whispersystems.textsecuregcm.storage.DeletedAccounts;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
@@ -48,14 +56,22 @@ import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.UsernamesManager;
import org.whispersystems.textsecuregcm.tests.util.JsonHelpers;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
class AccountsManagerTest {
private DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
private ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
private static final Answer<?> ACCOUNT_UPDATE_ANSWER = (answer) -> {
// it is implicit in the update() contract is that a successful call will
// result in an incremented version
final Account updatedAccount = answer.getArgument(0, Account.class);
updatedAccount.setVersion(updatedAccount.getVersion() + 1);
return null;
};
@BeforeEach
void setup() {
@@ -326,7 +342,7 @@ class AccountsManagerTest {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testUpdate_dynamoDbMigration(boolean dynamoEnabled) {
void testUpdate_dynamoDbMigration(boolean dynamoEnabled) throws IOException {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
Accounts accounts = mock(Accounts.class);
@@ -335,35 +351,59 @@ class AccountsManagerTest {
DirectoryQueue directoryQueue = mock(DirectoryQueue.class);
KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class);
MessagesManager messagesManager = mock(MessagesManager.class);
UsernamesManager usernamesManager = mock(UsernamesManager.class);
ProfilesManager profilesManager = mock(ProfilesManager.class);
SecureBackupClient secureBackupClient = mock(SecureBackupClient.class);
SecureStorageClient secureStorageClient = mock(SecureStorageClient.class);
UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]);
UsernamesManager usernamesManager = mock(UsernamesManager.class);
ProfilesManager profilesManager = mock(ProfilesManager.class);
SecureBackupClient secureBackupClient = mock(SecureBackupClient.class);
SecureStorageClient secureStorageClient = mock(SecureStorageClient.class);
UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]);
enableDynamo(dynamoEnabled);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
// database fetches should always return new instances
when(accounts.get(uuid)).thenReturn(Optional.of(new Account("+14152222222", uuid, new HashSet<>(), new byte[16])));
when(accountsDynamoDb.get(uuid)).thenReturn(Optional.of(new Account("+14152222222", uuid, new HashSet<>(), new byte[16])));
doAnswer(ACCOUNT_UPDATE_ANSWER).when(accounts).update(any(Account.class));
AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccounts,
directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager);
assertEquals(0, account.getDynamoDbMigrationVersion());
Account updatedAccount = accountsManager.update(account, a -> a.setProfileName("name"));
accountsManager.update(account);
assertThrows(AssertionError.class, account::getProfileName, "Account passed to update() should be stale");
assertEquals(1, account.getDynamoDbMigrationVersion());
assertNotSame(updatedAccount, account);
verify(accounts, times(1)).update(account);
verifyNoMoreInteractions(accounts);
verify(accountsDynamoDb, dynamoEnabled ? times(1) : never()).update(account);
if (dynamoEnabled) {
ArgumentCaptor<Account> argumentCaptor = ArgumentCaptor.forClass(Account.class);
verify(accountsDynamoDb, times(1)).update(argumentCaptor.capture());
assertEquals(uuid, argumentCaptor.getValue().getUuid());
} else {
verify(accountsDynamoDb, never()).update(any());
}
verify(accountsDynamoDb, dynamoEnabled ? times(1) : never()).get(uuid);
verifyNoMoreInteractions(accountsDynamoDb);
ArgumentCaptor<String> redisSetArgumentCapture = ArgumentCaptor.forClass(String.class);
verify(commands, times(4)).set(anyString(), redisSetArgumentCapture.capture());
Account firstAccountCached = JsonHelpers.fromJson(redisSetArgumentCapture.getAllValues().get(1), Account.class);
Account secondAccountCached = JsonHelpers.fromJson(redisSetArgumentCapture.getAllValues().get(3), Account.class);
// uuid is @JsonIgnore, so we need to set it for compareAccounts to work
firstAccountCached.setUuid(uuid);
secondAccountCached.setUuid(uuid);
assertEquals(Optional.empty(), accountsManager.compareAccounts(Optional.of(firstAccountCached), Optional.of(secondAccountCached)));
}
@Test
void testUpdate_dynamoConditionFailed() {
void testUpdate_dynamoMissing() {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
Accounts accounts = mock(Accounts.class);
@@ -382,25 +422,158 @@ class AccountsManagerTest {
enableDynamo(true);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
doThrow(ConditionalCheckFailedException.class).when(accountsDynamoDb).update(any(Account.class));
when(accountsDynamoDb.get(uuid)).thenReturn(Optional.empty());
doAnswer(ACCOUNT_UPDATE_ANSWER).when(accounts).update(any());
doAnswer(ACCOUNT_UPDATE_ANSWER).when(accountsDynamoDb).update(any());
AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccounts,
directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager);
assertEquals(0, account.getDynamoDbMigrationVersion());
accountsManager.update(account);
assertEquals(1, account.getDynamoDbMigrationVersion());
Account updatedAccount = accountsManager.update(account, a -> {});
verify(accounts, times(1)).update(account);
verifyNoMoreInteractions(accounts);
verify(accountsDynamoDb, times(1)).update(account);
verify(accountsDynamoDb, times(1)).create(account);
verify(accountsDynamoDb, never()).update(account);
verify(accountsDynamoDb, times(1)).get(uuid);
verifyNoMoreInteractions(accountsDynamoDb);
assertEquals(1, updatedAccount.getVersion());
}
@Test
void testUpdate_optimisticLockingFailure() {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
Accounts accounts = mock(Accounts.class);
AccountsDynamoDb accountsDynamoDb = mock(AccountsDynamoDb.class);
DeletedAccounts deletedAccounts = mock(DeletedAccounts.class);
DirectoryQueue directoryQueue = mock(DirectoryQueue.class);
KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class);
MessagesManager messagesManager = mock(MessagesManager.class);
UsernamesManager usernamesManager = mock(UsernamesManager.class);
ProfilesManager profilesManager = mock(ProfilesManager.class);
SecureBackupClient secureBackupClient = mock(SecureBackupClient.class);
SecureStorageClient secureStorageClient = mock(SecureStorageClient.class);
UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]);
enableDynamo(true);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.get(uuid)).thenReturn(Optional.of(new Account("+14152222222", uuid, new HashSet<>(), new byte[16])));
doThrow(ContestedOptimisticLockException.class)
.doAnswer(ACCOUNT_UPDATE_ANSWER)
.when(accounts).update(any());
when(accountsDynamoDb.get(uuid)).thenReturn(Optional.of(new Account("+14152222222", uuid, new HashSet<>(), new byte[16])));
doThrow(ContestedOptimisticLockException.class)
.doAnswer(ACCOUNT_UPDATE_ANSWER)
.when(accountsDynamoDb).update(any());
AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccounts, directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager);
account = accountsManager.update(account, a -> a.setProfileName("name"));
assertEquals(1, account.getVersion());
assertEquals("name", account.getProfileName());
verify(accounts, times(1)).get(uuid);
verify(accounts, times(2)).update(any());
verifyNoMoreInteractions(accounts);
// dynamo has an extra get() because the account is fetched before every update
verify(accountsDynamoDb, times(2)).get(uuid);
verify(accountsDynamoDb, times(2)).update(any());
verifyNoMoreInteractions(accountsDynamoDb);
}
@Test
void testUpdate_dynamoOptimisticLockingFailureDuringCreate() {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
Accounts accounts = mock(Accounts.class);
AccountsDynamoDb accountsDynamoDb = mock(AccountsDynamoDb.class);
DeletedAccounts deletedAccounts = mock(DeletedAccounts.class);
DirectoryQueue directoryQueue = mock(DirectoryQueue.class);
KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class);
MessagesManager messagesManager = mock(MessagesManager.class);
UsernamesManager usernamesManager = mock(UsernamesManager.class);
ProfilesManager profilesManager = mock(ProfilesManager.class);
SecureBackupClient secureBackupClient = mock(SecureBackupClient.class);
SecureStorageClient secureStorageClient = mock(SecureStorageClient.class);
UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]);
enableDynamo(true);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accountsDynamoDb.get(uuid)).thenReturn(Optional.empty())
.thenReturn(Optional.of(account));
when(accountsDynamoDb.create(any())).thenThrow(ContestedOptimisticLockException.class);
AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccounts, directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager);
accountsManager.update(account, a -> {});
verify(accounts, times(1)).update(account);
verifyNoMoreInteractions(accounts);
verify(accountsDynamoDb, times(1)).get(uuid);
verifyNoMoreInteractions(accountsDynamoDb);
}
@Test
void testUpdateDevice() throws Exception {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
Accounts accounts = mock(Accounts.class);
AccountsDynamoDb accountsDynamoDb = mock(AccountsDynamoDb.class);
DeletedAccounts deletedAccounts = mock(DeletedAccounts.class);
DirectoryQueue directoryQueue = mock(DirectoryQueue.class);
KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class);
MessagesManager messagesManager = mock(MessagesManager.class);
UsernamesManager usernamesManager = mock(UsernamesManager.class);
ProfilesManager profilesManager = mock(ProfilesManager.class);
SecureBackupClient secureBackupClient = mock(SecureBackupClient.class);
SecureStorageClient secureStorageClient = mock(SecureStorageClient.class);
AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccounts, directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager);
assertEquals(Optional.empty(), accountsManager.compareAccounts(Optional.empty(), Optional.empty()));
final UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]);
when(accounts.get(uuid)).thenReturn(Optional.of(new Account("+14152222222", uuid, new HashSet<>(), new byte[16])));
assertTrue(account.getDevices().isEmpty());
Device enabledDevice = new Device();
enabledDevice.setFetchesMessages(true);
enabledDevice.setSignedPreKey(new SignedPreKey(1L, "key", "signature"));
enabledDevice.setLastSeen(System.currentTimeMillis());
final long deviceId = account.getNextDeviceId();
enabledDevice.setId(deviceId);
account.addDevice(enabledDevice);
@SuppressWarnings("unchecked") Consumer<Device> deviceUpdater = mock(Consumer.class);
@SuppressWarnings("unchecked") Consumer<Device> unknownDeviceUpdater = mock(Consumer.class);
account = accountsManager.updateDevice(account, deviceId, deviceUpdater);
account = accountsManager.updateDevice(account, deviceId, d -> d.setName("deviceName"));
assertEquals("deviceName", account.getDevice(deviceId).get().getName());
verify(deviceUpdater, times(1)).accept(any(Device.class));
accountsManager.updateDevice(account, account.getNextDeviceId(), unknownDeviceUpdater);
verify(unknownDeviceUpdater, never()).accept(any(Device.class));
}
@Test
void testCompareAccounts() throws Exception {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
@@ -479,9 +652,11 @@ class AccountsManagerTest {
assertEquals(Optional.empty(), accountsManager.compareAccounts(Optional.of(a1), Optional.of(a2)));
a1.setDynamoDbMigrationVersion(1);
a1.setVersion(1);
assertEquals(Optional.empty(), accountsManager.compareAccounts(Optional.of(a1), Optional.of(a2)));
assertEquals(Optional.of("version"), accountsManager.compareAccounts(Optional.of(a1), Optional.of(a2)));
a2.setVersion(1);
a2.setProfileName("name");

View File

@@ -167,6 +167,12 @@ public class AccountsTest {
accounts.update(account);
account.setProfileName("profileName");
accounts.update(account);
assertThat(account.getVersion()).isEqualTo(2);
Optional<Account> retrieved = accounts.get("+14151112222");
assertThat(retrieved.isPresent()).isTrue();
@@ -359,6 +365,7 @@ public class AccountsTest {
assertThat(result.getNumber()).isEqualTo(number);
assertThat(result.getLastSeen()).isEqualTo(expecting.getLastSeen());
assertThat(result.getUuid()).isEqualTo(uuid);
assertThat(result.getVersion()).isEqualTo(expecting.getVersion());
assertThat(Arrays.equals(result.getUnidentifiedAccessKey().get(), expecting.getUnidentifiedAccessKey().get())).isTrue();
for (Device expectingDevice : expecting.getDevices()) {

View File

@@ -5,16 +5,16 @@
package org.whispersystems.textsecuregcm.tests.storage;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerRestartException;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PushFeedbackProcessor;
import org.whispersystems.textsecuregcm.util.Util;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.isNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import java.util.Collections;
import java.util.List;
@@ -22,11 +22,20 @@ import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerRestartException;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PushFeedbackProcessor;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.util.Util;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.*;
public class PushFeedbackProcessorTest {
class PushFeedbackProcessorTest {
private AccountsManager accountsManager = mock(AccountsManager.class);
private DirectoryQueue directoryQueue = mock(DirectoryQueue.class);
@@ -46,8 +55,10 @@ public class PushFeedbackProcessorTest {
private Device stillActiveDevice = mock(Device.class);
private Device undiscoverableDevice = mock(Device.class);
@Before
public void setup() {
@BeforeEach
void setup() {
AccountsHelper.setupMockUpdate(accountsManager);
when(uninstalledDevice.getUninstalledFeedbackTimestamp()).thenReturn(Util.todayInMillis() - TimeUnit.DAYS.toMillis(2));
when(uninstalledDevice.getLastSeen()).thenReturn(Util.todayInMillis() - TimeUnit.DAYS.toMillis(2));
when(uninstalledDeviceTwo.getUninstalledFeedbackTimestamp()).thenReturn(Util.todayInMillis() - TimeUnit.DAYS.toMillis(3));
@@ -85,7 +96,7 @@ public class PushFeedbackProcessorTest {
@Test
public void testEmpty() throws AccountDatabaseCrawlerRestartException {
void testEmpty() throws AccountDatabaseCrawlerRestartException {
PushFeedbackProcessor processor = new PushFeedbackProcessor(accountsManager, directoryQueue);
processor.timeAndProcessCrawlChunk(Optional.of(UUID.randomUUID()), Collections.emptyList());
@@ -94,7 +105,7 @@ public class PushFeedbackProcessorTest {
}
@Test
public void testUpdate() throws AccountDatabaseCrawlerRestartException {
void testUpdate() throws AccountDatabaseCrawlerRestartException {
PushFeedbackProcessor processor = new PushFeedbackProcessor(accountsManager, directoryQueue);
processor.timeAndProcessCrawlChunk(Optional.of(UUID.randomUUID()), List.of(uninstalledAccount, mixedAccount, stillActiveAccount, freshAccount, cleanAccount, undiscoverableAccount));
@@ -102,7 +113,7 @@ public class PushFeedbackProcessorTest {
verify(uninstalledDevice).setGcmId(isNull());
verify(uninstalledDevice).setFetchesMessages(eq(false));
verify(accountsManager).update(eq(uninstalledAccount));
verify(accountsManager).update(eq(uninstalledAccount), any());
verify(uninstalledDeviceTwo).setApnId(isNull());
verify(uninstalledDeviceTwo).setGcmId(isNull());
@@ -112,33 +123,35 @@ public class PushFeedbackProcessorTest {
verify(installedDevice, never()).setGcmId(any());
verify(installedDevice, never()).setFetchesMessages(anyBoolean());
verify(accountsManager).update(eq(mixedAccount));
verify(accountsManager).update(eq(mixedAccount), any());
verify(recentUninstalledDevice, never()).setApnId(any());
verify(recentUninstalledDevice, never()).setGcmId(any());
verify(recentUninstalledDevice, never()).setFetchesMessages(anyBoolean());
verify(accountsManager, never()).update(eq(freshAccount));
verify(accountsManager, never()).update(eq(freshAccount), any());
verify(installedDeviceTwo, never()).setApnId(any());
verify(installedDeviceTwo, never()).setGcmId(any());
verify(installedDeviceTwo, never()).setFetchesMessages(anyBoolean());
verify(accountsManager, never()).update(eq(cleanAccount));
verify(accountsManager, never()).update(eq(cleanAccount), any());
verify(stillActiveDevice).setUninstalledFeedbackTimestamp(eq(0L));
verify(stillActiveDevice, never()).setApnId(any());
verify(stillActiveDevice, never()).setGcmId(any());
verify(stillActiveDevice, never()).setFetchesMessages(anyBoolean());
verify(accountsManager).update(eq(stillActiveAccount));
verify(accountsManager).update(eq(stillActiveAccount), any());
final ArgumentCaptor<List<Account>> refreshedAccountArgumentCaptor = ArgumentCaptor.forClass(List.class);
verify(directoryQueue).refreshRegisteredUsers(refreshedAccountArgumentCaptor.capture());
assertTrue(refreshedAccountArgumentCaptor.getValue().containsAll(List.of(undiscoverableAccount, uninstalledAccount)));
final List<UUID> refreshedUuids = refreshedAccountArgumentCaptor.getValue().stream()
.map(Account::getUuid)
.collect(Collectors.toList());
assertTrue(refreshedUuids.containsAll(List.of(undiscoverableAccount.getUuid(), uninstalledAccount.getUuid())));
}
}

View File

@@ -0,0 +1,148 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.util;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockingDetails;
import static org.mockito.Mockito.when;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.util.function.Consumer;
import org.mockito.MockingDetails;
import org.mockito.stubbing.Stubbing;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.SystemMapper;
public class AccountsHelper {
public static void setupMockUpdate(final AccountsManager mockAccountsManager) {
when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
answer.getArgument(1, Consumer.class).accept(account);
return copyAndMarkStale(account);
});
when(mockAccountsManager.updateDevice(any(), anyLong(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
final Long deviceId = answer.getArgument(1, Long.class);
account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class));
return copyAndMarkStale(account);
});
}
private static Account copyAndMarkStale(Account account) throws IOException {
MockingDetails mockingDetails = mockingDetails(account);
final Account updatedAccount;
if (mockingDetails.isMock()) {
updatedAccount = mock(Account.class);
// its not possible to make `account` behave as if it were stale, because we use static mocks in AuthHelper
for (Stubbing stubbing : mockingDetails.getStubbings()) {
switch (stubbing.getInvocation().getMethod().getName()) {
case "getUuid": {
when(updatedAccount.getUuid()).thenAnswer(stubbing);
break;
}
case "getNumber": {
when(updatedAccount.getNumber()).thenAnswer(stubbing);
break;
}
case "getDevices": {
when(updatedAccount.getDevices())
.thenAnswer(stubbing);
break;
}
case "getDevice": {
when(updatedAccount.getDevice(stubbing.getInvocation().getArgument(0)))
.thenAnswer(stubbing);
break;
}
case "getMasterDevice": {
when(updatedAccount.getMasterDevice()).thenAnswer(stubbing);
break;
}
case "getAuthenticatedDevice": {
when(updatedAccount.getAuthenticatedDevice()).thenAnswer(stubbing);
break;
}
case "isEnabled": {
when(updatedAccount.isEnabled()).thenAnswer(stubbing);
break;
}
case "isDiscoverableByPhoneNumber": {
when(updatedAccount.isDiscoverableByPhoneNumber()).thenAnswer(stubbing);
break;
}
case "getNextDeviceId": {
when(updatedAccount.getNextDeviceId()).thenAnswer(stubbing);
break;
}
case "isGroupsV2Supported": {
when(updatedAccount.isGroupsV2Supported()).thenAnswer(stubbing);
break;
}
case "isGv1MigrationSupported": {
when(updatedAccount.isGv1MigrationSupported()).thenAnswer(stubbing);
break;
}
case "isSenderKeySupported": {
when(updatedAccount.isSenderKeySupported()).thenAnswer(stubbing);
break;
}
case "isAnnouncementGroupSupported": {
when(updatedAccount.isAnnouncementGroupSupported()).thenAnswer(stubbing);
break;
}
case "getEnabledDeviceCount": {
when(updatedAccount.getEnabledDeviceCount()).thenAnswer(stubbing);
break;
}
case "getRelay": {
// TODO unused
when(updatedAccount.getRelay()).thenAnswer(stubbing);
break;
}
case "getRegistrationLock": {
when(updatedAccount.getRegistrationLock()).thenAnswer(stubbing);
break;
}
case "getIdentityKey": {
when(updatedAccount.getIdentityKey()).thenAnswer(stubbing);
break;
}
default: {
throw new IllegalArgumentException(
"unsupported method: Account#" + stubbing.getInvocation().getMethod().getName());
}
}
}
} else {
final ObjectMapper mapper = SystemMapper.getMapper();
updatedAccount = mapper.readValue(mapper.writeValueAsBytes(account), Account.class);
updatedAccount.setNumber(account.getNumber());
account.markStale();
}
return updatedAccount;
}
public static Account eqUuid(Account value) {
return argThat(other -> other.getUuid().equals(value.getUuid()));
}
}