Add a framework for running experiments to improve push notification reliability

This commit is contained in:
Jon Chambers
2024-07-25 11:36:05 -04:00
committed by GitHub
parent 1fe6dac760
commit 4ebad2c473
16 changed files with 1489 additions and 8 deletions

View File

@@ -0,0 +1,331 @@
package org.whispersystems.textsecuregcm.experiment;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.time.Clock;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ThreadLocalRandom;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtension;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import reactor.util.function.Tuples;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
import software.amazon.awssdk.services.dynamodb.model.Select;
import javax.annotation.Nullable;
class PushNotificationExperimentSamplesTest {
private PushNotificationExperimentSamples pushNotificationExperimentSamples;
@RegisterExtension
public static final DynamoDbExtension DYNAMO_DB_EXTENSION =
new DynamoDbExtension(DynamoDbExtensionSchema.Tables.PUSH_NOTIFICATION_EXPERIMENT_SAMPLES);
private record TestDeviceState(int bounciness) {
}
@BeforeEach
void setUp() {
pushNotificationExperimentSamples =
new PushNotificationExperimentSamples(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.PUSH_NOTIFICATION_EXPERIMENT_SAMPLES.tableName(),
Clock.systemUTC());
}
@Test
void recordInitialState() throws JsonProcessingException {
final String experimentName = "test-experiment";
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final boolean inExperimentGroup = ThreadLocalRandom.current().nextBoolean();
final int bounciness = ThreadLocalRandom.current().nextInt();
assertTrue(pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName,
inExperimentGroup,
new TestDeviceState(bounciness))
.join(),
"Attempt to record an initial state should succeed for entirely new records");
assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), null),
getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class));
assertTrue(pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName,
inExperimentGroup,
new TestDeviceState(bounciness))
.join(),
"Attempt to re-record an initial state should succeed for existing-but-unchanged records");
assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), null),
getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class),
"Recorded initial state should be unchanged after repeated write");
assertFalse(pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName,
inExperimentGroup,
new TestDeviceState(bounciness + 1))
.join(),
"Attempt to record a conflicting initial state should fail");
assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), null),
getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class),
"Recorded initial state should be unchanged after unsuccessful write");
assertFalse(pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName,
!inExperimentGroup,
new TestDeviceState(bounciness))
.join(),
"Attempt to record a new group assignment should fail");
assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), null),
getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class),
"Recorded initial state should be unchanged after unsuccessful write");
final int finalBounciness = bounciness + 17;
pushNotificationExperimentSamples.recordFinalState(accountIdentifier,
deviceId,
experimentName,
new TestDeviceState(finalBounciness))
.join();
assertFalse(pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName,
inExperimentGroup,
new TestDeviceState(bounciness))
.join(),
"Attempt to record an initial state should fail for samples with final states");
assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), new TestDeviceState(finalBounciness)),
getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class),
"Recorded initial state should be unchanged after unsuccessful write");
}
@Test
void recordFinalState() throws JsonProcessingException {
final String experimentName = "test-experiment";
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final boolean inExperimentGroup = ThreadLocalRandom.current().nextBoolean();
final int initialBounciness = ThreadLocalRandom.current().nextInt();
final int finalBounciness = initialBounciness + 17;
{
pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName,
inExperimentGroup,
new TestDeviceState(initialBounciness))
.join();
final PushNotificationExperimentSample<TestDeviceState> returnedSample =
pushNotificationExperimentSamples.recordFinalState(accountIdentifier,
deviceId,
experimentName,
new TestDeviceState(finalBounciness))
.join();
final PushNotificationExperimentSample<TestDeviceState> expectedSample =
new PushNotificationExperimentSample<>(inExperimentGroup,
new TestDeviceState(initialBounciness),
new TestDeviceState(finalBounciness));
assertEquals(expectedSample, returnedSample,
"Attempt to update existing sample without final state should succeed");
assertEquals(expectedSample, getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class),
"Attempt to update existing sample without final state should be persisted");
}
assertThrows(CompletionException.class, () -> pushNotificationExperimentSamples.recordFinalState(accountIdentifier,
(byte) (deviceId + 1),
experimentName,
new TestDeviceState(finalBounciness))
.join(),
"Attempts to record a final state without an initial state should fail");
}
@SuppressWarnings("SameParameterValue")
private <T> PushNotificationExperimentSample<T> getSample(final UUID accountIdentifier,
final byte deviceId,
final String experimentName,
final Class<T> stateClass) throws JsonProcessingException {
final GetItemResponse response = DYNAMO_DB_EXTENSION.getDynamoDbClient().getItem(GetItemRequest.builder()
.tableName(DynamoDbExtensionSchema.Tables.PUSH_NOTIFICATION_EXPERIMENT_SAMPLES.tableName())
.key(Map.of(
PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME, AttributeValue.fromS(experimentName),
PushNotificationExperimentSamples.ATTR_ACI_AND_DEVICE_ID, PushNotificationExperimentSamples.buildSortKey(accountIdentifier, deviceId)))
.build());
final boolean inExperimentGroup =
response.item().get(PushNotificationExperimentSamples.ATTR_IN_EXPERIMENT_GROUP).bool();
final T initialState =
SystemMapper.jsonMapper().readValue(response.item().get(PushNotificationExperimentSamples.ATTR_INITIAL_STATE).s(), stateClass);
@Nullable final T finalState = response.item().containsKey(PushNotificationExperimentSamples.ATTR_FINAL_STATE)
? SystemMapper.jsonMapper().readValue(response.item().get(PushNotificationExperimentSamples.ATTR_FINAL_STATE).s(), stateClass)
: null;
return new PushNotificationExperimentSample<>(inExperimentGroup, initialState, finalState);
}
@Test
void getDevicesPendingFinalState() throws JsonProcessingException {
final String experimentName = "test-experiment";
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final boolean inExperimentGroup = ThreadLocalRandom.current().nextBoolean();
final int initialBounciness = ThreadLocalRandom.current().nextInt();
//noinspection DataFlowIssue
assertTrue(pushNotificationExperimentSamples.getDevicesPendingFinalState(experimentName).collectList().block().isEmpty());
pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName,
inExperimentGroup,
new TestDeviceState(initialBounciness))
.join();
pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
(byte) (deviceId + 1),
experimentName + "-different",
inExperimentGroup,
new TestDeviceState(initialBounciness))
.join();
assertEquals(List.of(Tuples.of(accountIdentifier, deviceId)),
pushNotificationExperimentSamples.getDevicesPendingFinalState(experimentName).collectList().block());
}
@Test
void getFinishedSamples() throws JsonProcessingException {
final String experimentName = "test-experiment";
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final boolean inExperimentGroup = ThreadLocalRandom.current().nextBoolean();
final int initialBounciness = ThreadLocalRandom.current().nextInt();
final int finalBounciness = initialBounciness + 17;
//noinspection DataFlowIssue
assertTrue(pushNotificationExperimentSamples.getFinishedSamples(experimentName, TestDeviceState.class).collectList().block().isEmpty());
pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName,
inExperimentGroup,
new TestDeviceState(initialBounciness))
.join();
//noinspection DataFlowIssue
assertTrue(pushNotificationExperimentSamples.getFinishedSamples(experimentName, TestDeviceState.class).collectList().block().isEmpty(),
"Publisher should not return unfinished samples");
pushNotificationExperimentSamples.recordFinalState(accountIdentifier,
deviceId,
experimentName,
new TestDeviceState(finalBounciness))
.join();
final List<PushNotificationExperimentSample<TestDeviceState>> expectedSamples =
List.of(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(initialBounciness), new TestDeviceState(finalBounciness)));
assertEquals(
expectedSamples,
pushNotificationExperimentSamples.getFinishedSamples(experimentName, TestDeviceState.class).collectList().block(),
"Publisher should return finished samples");
pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName + "-different",
inExperimentGroup,
new TestDeviceState(initialBounciness))
.join();
pushNotificationExperimentSamples.recordFinalState(accountIdentifier,
deviceId,
experimentName + "-different",
new TestDeviceState(finalBounciness))
.join();
assertEquals(
expectedSamples,
pushNotificationExperimentSamples.getFinishedSamples(experimentName, TestDeviceState.class).collectList().block(),
"Publisher should return finished samples only from named experiment");
}
@Test
void discardSamples() throws JsonProcessingException {
final String discardSamplesExperimentName = "discard-experiment";
final String retainSamplesExperimentName = "retain-experiment";
final int sampleCount = 16;
for (int i = 0; i < sampleCount; i++) {
pushNotificationExperimentSamples.recordInitialState(UUID.randomUUID(),
Device.PRIMARY_ID,
discardSamplesExperimentName,
ThreadLocalRandom.current().nextBoolean(),
new TestDeviceState(ThreadLocalRandom.current().nextInt()))
.join();
pushNotificationExperimentSamples.recordInitialState(UUID.randomUUID(),
Device.PRIMARY_ID,
retainSamplesExperimentName,
ThreadLocalRandom.current().nextBoolean(),
new TestDeviceState(ThreadLocalRandom.current().nextInt()))
.join();
}
pushNotificationExperimentSamples.discardSamples(discardSamplesExperimentName, 1).join();
{
final QueryResponse queryResponse = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().query(QueryRequest.builder()
.tableName(DynamoDbExtensionSchema.Tables.PUSH_NOTIFICATION_EXPERIMENT_SAMPLES.tableName())
.keyConditionExpression("#experiment = :experiment")
.expressionAttributeNames(Map.of("#experiment", PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME))
.expressionAttributeValues(Map.of(":experiment", AttributeValue.fromS(discardSamplesExperimentName)))
.select(Select.COUNT)
.build())
.join();
assertEquals(0, queryResponse.count());
}
{
final QueryResponse queryResponse = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().query(QueryRequest.builder()
.tableName(DynamoDbExtensionSchema.Tables.PUSH_NOTIFICATION_EXPERIMENT_SAMPLES.tableName())
.keyConditionExpression("#experiment = :experiment")
.expressionAttributeNames(Map.of("#experiment", PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME))
.expressionAttributeValues(Map.of(":experiment", AttributeValue.fromS(retainSamplesExperimentName)))
.select(Select.COUNT)
.build())
.join();
assertEquals(sampleCount, queryResponse.count());
}
}
}

View File

@@ -9,6 +9,7 @@ import java.util.Collections;
import java.util.List;
import org.whispersystems.textsecuregcm.backup.BackupsDb;
import org.whispersystems.textsecuregcm.scheduler.JobScheduler;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.GlobalSecondaryIndex;
import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement;
@@ -141,6 +142,20 @@ public final class DynamoDbExtensionSchema {
.build()),
List.of(), List.of()),
PUSH_NOTIFICATION_EXPERIMENT_SAMPLES("push_notification_experiment_samples_test",
PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME,
PushNotificationExperimentSamples.ATTR_ACI_AND_DEVICE_ID,
List.of(
AttributeDefinition.builder()
.attributeName(PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME)
.attributeType(ScalarAttributeType.S)
.build(),
AttributeDefinition.builder()
.attributeName(PushNotificationExperimentSamples.ATTR_ACI_AND_DEVICE_ID)
.attributeType(ScalarAttributeType.B)
.build()),
List.of(), List.of()),
REPEATED_USE_EC_SIGNED_PRE_KEYS("repeated_use_signed_ec_pre_keys_test",
RepeatedUseSignedPreKeyStore.KEY_ACCOUNT_UUID,
RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID,

View File

@@ -0,0 +1,252 @@
package org.whispersystems.textsecuregcm.workers;
import com.fasterxml.jackson.core.JsonProcessingException;
import net.sourceforge.argparse4j.inf.Namespace;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSample;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import reactor.core.publisher.Flux;
import reactor.util.function.Tuples;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
class FinishPushNotificationExperimentCommandTest {
private CommandDependencies commandDependencies;
private PushNotificationExperiment<String> experiment;
private FinishPushNotificationExperimentCommand<String> finishPushNotificationExperimentCommand;
private static final String EXPERIMENT_NAME = "test";
private static class TestFinishPushNotificationExperimentCommand extends FinishPushNotificationExperimentCommand<String> {
public TestFinishPushNotificationExperimentCommand(final PushNotificationExperiment<String> experiment) {
super("test-finish-push-notification-experiment",
"Test start push notification experiment command",
(ignoredDependencies, ignoredConfiguration) -> experiment);
}
}
@BeforeEach
void setUp() throws JsonProcessingException {
final AccountsManager accountsManager = mock(AccountsManager.class);
final PushNotificationExperimentSamples pushNotificationExperimentSamples =
mock(PushNotificationExperimentSamples.class);
when(pushNotificationExperimentSamples.recordFinalState(any(), anyByte(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(new PushNotificationExperimentSample<>(true, "test", "test")));
commandDependencies = new CommandDependencies(accountsManager,
null,
null,
null,
null,
null,
null,
null,
pushNotificationExperimentSamples,
null,
null,
null,
null,
null);
//noinspection unchecked
experiment = mock(PushNotificationExperiment.class);
when(experiment.getExperimentName()).thenReturn(EXPERIMENT_NAME);
when(experiment.getState(any(), any())).thenReturn("test");
finishPushNotificationExperimentCommand = new TestFinishPushNotificationExperimentCommand(experiment);
}
@Test
void run() throws Exception {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
final Account account = mock(Account.class);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME))
.thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId)));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null,
new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)),
null,
commandDependencies));
verify(experiment).getState(account, device);
verify(commandDependencies.pushNotificationExperimentSamples())
.recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any());
}
@Test
void runMissingAccount() throws Exception {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME))
.thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId)));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null,
new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)),
null,
commandDependencies));
verify(experiment).getState(null, null);
verify(commandDependencies.pushNotificationExperimentSamples())
.recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any());
}
@Test
void runMissingDevice() throws Exception {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final Account account = mock(Account.class);
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME))
.thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId)));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null,
new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)),
null,
commandDependencies));
verify(experiment).getState(account, null);
verify(commandDependencies.pushNotificationExperimentSamples())
.recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any());
}
@Test
void runAccountFetchRetry() throws Exception {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
final Account account = mock(Account.class);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME))
.thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId)));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null,
new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)),
null,
commandDependencies));
verify(commandDependencies.accountsManager(), times(3)).getByAccountIdentifierAsync(accountIdentifier);
verify(experiment).getState(account, device);
verify(commandDependencies.pushNotificationExperimentSamples())
.recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any());
}
@Test
void runStoreSampleRetry() throws Exception {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
final Account account = mock(Account.class);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME))
.thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId)));
when(commandDependencies.pushNotificationExperimentSamples().recordFinalState(any(), anyByte(), any(), any()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()))
.thenReturn(CompletableFuture.completedFuture(new PushNotificationExperimentSample<>(true, "test", "test")));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null,
new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)),
null,
commandDependencies));
verify(experiment).getState(account, device);
verify(commandDependencies.pushNotificationExperimentSamples(), times(3))
.recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any());
}
@Test
void runMissingInitialSample() throws Exception {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
final Account account = mock(Account.class);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME))
.thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId)));
when(commandDependencies.pushNotificationExperimentSamples().recordFinalState(any(), anyByte(), any(), any()))
.thenReturn(CompletableFuture.failedFuture(ConditionalCheckFailedException.builder().build()));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null,
new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)),
null,
commandDependencies));
verify(experiment).getState(account, device);
verify(commandDependencies.pushNotificationExperimentSamples())
.recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any());
}
}

View File

@@ -0,0 +1,166 @@
package org.whispersystems.textsecuregcm.workers;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import net.sourceforge.argparse4j.inf.Namespace;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import reactor.core.publisher.Flux;
class StartPushNotificationExperimentCommandTest {
private PushNotificationExperimentSamples pushNotificationExperimentSamples;
private PushNotificationExperiment<String> experiment;
private StartPushNotificationExperimentCommand<String> startPushNotificationExperimentCommand;
// Taken together, these parameters will produce a device that's enrolled in the experimental group (as opposed to the
// control group) for an experiment.
private static final UUID ACCOUNT_IDENTIFIER = UUID.fromString("341fb18f-9dee-4181-bc40-e485958341d3");
private static final byte DEVICE_ID = Device.PRIMARY_ID;
private static final String EXPERIMENT_NAME = "test";
private static class TestStartPushNotificationExperimentCommand extends StartPushNotificationExperimentCommand<String> {
private final CommandDependencies commandDependencies;
public TestStartPushNotificationExperimentCommand(
final PushNotificationExperimentSamples pushNotificationExperimentSamples,
final PushNotificationExperiment<String> experiment) {
super("test-start-push-notification-experiment",
"Test start push notification experiment command",
(ignoredDependencies, ignoredConfiguration) -> experiment);
this.commandDependencies = new CommandDependencies(null,
null,
null,
null,
null,
null,
null,
null,
pushNotificationExperimentSamples,
null,
null,
null,
null,
null);
}
@Override
protected Namespace getNamespace() {
return new Namespace(Map.of(StartPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1));
}
@Override
protected CommandDependencies getCommandDependencies() {
return commandDependencies;
}
}
@BeforeEach
void setUp() {
//noinspection unchecked
experiment = mock(PushNotificationExperiment.class);
when(experiment.getExperimentName()).thenReturn(EXPERIMENT_NAME);
when(experiment.isDeviceEligible(any(), any())).thenReturn(CompletableFuture.completedFuture(true));
when(experiment.getState(any(), any())).thenReturn("test");
when(experiment.applyExperimentTreatment(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
pushNotificationExperimentSamples = mock(PushNotificationExperimentSamples.class);
try {
when(pushNotificationExperimentSamples.recordInitialState(any(), anyByte(), any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(true));
} catch (final JsonProcessingException e) {
throw new AssertionError(e);
}
startPushNotificationExperimentCommand =
new TestStartPushNotificationExperimentCommand(pushNotificationExperimentSamples, experiment);
}
@Test
void crawlAccounts() {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(DEVICE_ID);
final Account account = mock(Account.class);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTIFIER);
when(account.getDevices()).thenReturn(List.of(device));
assertDoesNotThrow(() -> startPushNotificationExperimentCommand.crawlAccounts(Flux.just(account)));
verify(experiment).applyExperimentTreatment(account, device);
}
@Test
void crawlAccountsExistingSample() throws JsonProcessingException {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(DEVICE_ID);
final Account account = mock(Account.class);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTIFIER);
when(account.getDevices()).thenReturn(List.of(device));
when(pushNotificationExperimentSamples.recordInitialState(any(), anyByte(), any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(false));
assertDoesNotThrow(() -> startPushNotificationExperimentCommand.crawlAccounts(Flux.just(account)));
verify(experiment, never()).applyExperimentTreatment(account, device);
}
@Test
void crawlAccountsSampleRetry() throws JsonProcessingException {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(DEVICE_ID);
final Account account = mock(Account.class);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTIFIER);
when(account.getDevices()).thenReturn(List.of(device));
when(pushNotificationExperimentSamples.recordInitialState(any(), anyByte(), any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()))
.thenReturn(CompletableFuture.completedFuture(true));
assertDoesNotThrow(() -> startPushNotificationExperimentCommand.crawlAccounts(Flux.just(account)));
verify(experiment).applyExperimentTreatment(account, device);
verify(pushNotificationExperimentSamples, times(3))
.recordInitialState(ACCOUNT_IDENTIFIER, DEVICE_ID, EXPERIMENT_NAME, true, "test");
}
@Test
void crawlAccountsExperimentException() {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(DEVICE_ID);
final Account account = mock(Account.class);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTIFIER);
when(account.getDevices()).thenReturn(List.of(device));
when(experiment.applyExperimentTreatment(account, device))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()));
assertDoesNotThrow(() -> startPushNotificationExperimentCommand.crawlAccounts(Flux.just(account)));
verify(experiment).applyExperimentTreatment(account, device);
}
}