Tune the "finish push notification experiment" command

This commit is contained in:
Jon Chambers
2024-08-05 15:02:24 -04:00
committed by GitHub
parent 0e4625ef88
commit 8c61d45206
8 changed files with 298 additions and 274 deletions

View File

@@ -68,6 +68,11 @@ public class NotifyIdleDevicesWithoutMessagesPushNotificationExperiment implemen
.thenApply(mayHavePersistedMessages -> !mayHavePersistedMessages);
}
@Override
public Class<DeviceLastSeenState> getStateClass() {
return DeviceLastSeenState.class;
}
@VisibleForTesting
static boolean hasPushToken(final Device device) {
// Exclude VOIP tokens since they have their own, distinct delivery mechanism

View File

@@ -1,3 +1,4 @@
package org.whispersystems.textsecuregcm.experiment;
import org.whispersystems.textsecuregcm.storage.Account;
@@ -32,6 +33,13 @@ public interface PushNotificationExperiment<T> {
*/
CompletableFuture<Boolean> isDeviceEligible(Account account, Device device);
/**
* Returns the class of the state object stored for this experiment.
*
* @return the class of the state object stored for this experiment
*/
Class<T> getStateClass();
/**
* Generates an experiment specific state "snapshot" of the given device. Experiment results are generally evaluated
* by comparing a device's state before a treatment is applied and its state after the treatment is applied.

View File

@@ -1,4 +1,11 @@
package org.whispersystems.textsecuregcm.experiment;
public record PushNotificationExperimentSample<T>(boolean inExperimentGroup, T initialState, T finalState) {
import javax.annotation.Nullable;
import java.util.UUID;
public record PushNotificationExperimentSample<T>(UUID accountIdentifier,
byte deviceId,
boolean inExperimentGroup,
T initialState,
@Nullable T finalState) {
}

View File

@@ -14,6 +14,7 @@ import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;
import reactor.util.retry.Retry;
@@ -23,7 +24,6 @@ import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
@@ -130,106 +130,129 @@ public class PushNotificationExperimentSamples {
* @param finalState the final state of the object; must be serializable as a JSON text and of the same type as the
* previously-stored initial state
* @return a future that completes when the final state has been stored; yields a finished sample if an initial sample
* was found or empty if no initial sample was found for the given account, device, and experiment
* @return A future that completes when the final state has been stored; yields a finished sample if an initial sample
* was found or empty if no initial sample was found for the given account, device, and experiment. The future may
* with a {@link JsonProcessingException} if the initial state could not be read or the final state could not be
* written as a JSON text.
*
* @param <T> the type of state object for this sample
*
* @throws JsonProcessingException if the given {@code finalState} could not be serialized as a JSON text
*/
public <T> CompletableFuture<PushNotificationExperimentSample<T>> recordFinalState(final UUID accountIdentifier,
final byte deviceId,
final String experimentName,
final T finalState) throws JsonProcessingException {
final T finalState) {
CompletableFuture<String> finalStateJsonFuture;
// Process the final state JSON on the calling thread, but inside a CompletionStage so there's just one "channel"
// for reporting JSON exceptions. The alternative is to `throw JsonProcessingException`, but then callers would have
// to both catch the exception when calling this method and also watch the returned future for the same exception.
try {
finalStateJsonFuture =
CompletableFuture.completedFuture(SystemMapper.jsonMapper().writeValueAsString(finalState));
} catch (final JsonProcessingException e) {
finalStateJsonFuture = CompletableFuture.failedFuture(e);
}
final AttributeValue aciAndDeviceIdAttributeValue = buildSortKey(accountIdentifier, deviceId);
return dynamoDbAsyncClient.updateItem(UpdateItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_EXPERIMENT_NAME, AttributeValue.fromS(experimentName),
ATTR_ACI_AND_DEVICE_ID, aciAndDeviceIdAttributeValue))
// `UpdateItem` will, by default, create a new item if one does not already exist for the given primary key. We
// want update-only-if-exists behavior, though, and so check that there's already an existing item for this ACI
// and device ID.
.conditionExpression("#aciAndDeviceId = :aciAndDeviceId")
.updateExpression("SET #finalState = if_not_exists(#finalState, :finalState)")
.expressionAttributeNames(Map.of(
"#aciAndDeviceId", ATTR_ACI_AND_DEVICE_ID,
"#finalState", ATTR_FINAL_STATE))
.expressionAttributeValues(Map.of(
":aciAndDeviceId", aciAndDeviceIdAttributeValue,
":finalState", AttributeValue.fromS(SystemMapper.jsonMapper().writeValueAsString(finalState))))
.returnValues(ReturnValue.ALL_NEW)
.build())
.thenApply(updateItemResponse -> {
try {
final boolean inExperimentGroup = updateItemResponse.attributes().get(ATTR_IN_EXPERIMENT_GROUP).bool();
return finalStateJsonFuture.thenCompose(finalStateJson -> {
return dynamoDbAsyncClient.updateItem(UpdateItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_EXPERIMENT_NAME, AttributeValue.fromS(experimentName),
ATTR_ACI_AND_DEVICE_ID, aciAndDeviceIdAttributeValue))
// `UpdateItem` will, by default, create a new item if one does not already exist for the given primary key. We
// want update-only-if-exists behavior, though, and so check that there's already an existing item for this ACI
// and device ID.
.conditionExpression("#aciAndDeviceId = :aciAndDeviceId")
.updateExpression("SET #finalState = if_not_exists(#finalState, :finalState)")
.expressionAttributeNames(Map.of(
"#aciAndDeviceId", ATTR_ACI_AND_DEVICE_ID,
"#finalState", ATTR_FINAL_STATE))
.expressionAttributeValues(Map.of(
":aciAndDeviceId", aciAndDeviceIdAttributeValue,
":finalState", AttributeValue.fromS(finalStateJson)))
.returnValues(ReturnValue.ALL_NEW)
.build())
.thenApply(updateItemResponse -> {
try {
final boolean inExperimentGroup = updateItemResponse.attributes().get(ATTR_IN_EXPERIMENT_GROUP).bool();
@SuppressWarnings("unchecked") final T parsedInitialState =
(T) parseState(updateItemResponse.attributes().get(ATTR_INITIAL_STATE).s(), finalState.getClass());
@SuppressWarnings("unchecked") final T parsedInitialState =
(T) parseState(updateItemResponse.attributes().get(ATTR_INITIAL_STATE).s(), finalState.getClass());
@SuppressWarnings("unchecked") final T parsedFinalState =
(T) parseState(updateItemResponse.attributes().get(ATTR_FINAL_STATE).s(), finalState.getClass());
@SuppressWarnings("unchecked") final T parsedFinalState =
(T) parseState(updateItemResponse.attributes().get(ATTR_FINAL_STATE).s(), finalState.getClass());
return new PushNotificationExperimentSample<>(inExperimentGroup, parsedInitialState, parsedFinalState);
} catch (final JsonProcessingException e) {
throw ExceptionUtils.wrap(e);
}
});
return new PushNotificationExperimentSample<>(accountIdentifier, deviceId, inExperimentGroup, parsedInitialState, parsedFinalState);
} catch (final JsonProcessingException e) {
throw ExceptionUtils.wrap(e);
}
});
});
}
/**
* Returns a publisher across all samples pending a final state for a given experiment.
* Returns a publisher across all samples for a given experiment.
*
* @param experimentName the name of the experiment for which to retrieve samples pending a final state
*
* @return a publisher across all samples pending a final state for a given experiment
*/
public Flux<Tuple2<UUID, Byte>> getDevicesPendingFinalState(final String experimentName) {
return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#experiment = :experiment")
.filterExpression("attribute_not_exists(#finalState)")
.expressionAttributeNames(Map.of(
"#experiment", KEY_EXPERIMENT_NAME,
"#finalState", ATTR_FINAL_STATE))
.expressionAttributeValues(Map.of(":experiment", AttributeValue.fromS(experimentName)))
.projectionExpression(ATTR_ACI_AND_DEVICE_ID)
.build())
.items())
.map(item -> parseSortKey(item.get(ATTR_ACI_AND_DEVICE_ID)));
}
/**
* Returns a publisher across all finished samples (i.e. samples with a recorded final state) for a given experiment.
*
* @param experimentName the name of the experiment for which to retrieve finished samples
* @param experimentName the name of the experiment for which to fetch samples
* @param stateClass the type of state object for sample in the given experiment
* @param totalSegments the number of segments into which the scan of the backing data store will be divided
*
* @return a publisher across all finished samples for the given experiment
* @return a publisher of tuples of ACI, device ID, and sample for all samples associated with the given experiment
*
* @param <T> the type of the sample's state objects
*
* @see <a href="https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Scan.html#Scan.ParallelScan">Working with scans - Parallel scan</a>
*/
public <T> Flux<PushNotificationExperimentSample<T>> getFinishedSamples(final String experimentName,
final Class<T> stateClass) {
return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#experiment = :experiment")
.filterExpression("attribute_exists(#finalState)")
.expressionAttributeNames(Map.of(
"#experiment", KEY_EXPERIMENT_NAME,
"#finalState", ATTR_FINAL_STATE))
public <T> Flux<PushNotificationExperimentSample<T>> getSamples(final String experimentName,
final Class<T> stateClass,
final int totalSegments,
final Scheduler scheduler) {
// Note that we're using a DynamoDB Scan operation instead of a Query. A Query would allow us to limit the search
// space to a specific experiment, but doesn't allow us to use segments. A Scan will always inspect all items in the
// table, but allows us to segment the search. Since we're generally calling this method in conjunction with "…and
// record a final state for the sample," distributing reads/writes across shards helps us avoid per-partition
// throughput limits. If we wind up with many concurrent experiments, it may be worthwhile to revisit this decision.
if (totalSegments < 1) {
throw new IllegalArgumentException("Total number of segments must be positive");
}
return Flux.range(0, totalSegments)
.parallel()
.runOn(scheduler)
.flatMap(segment -> getSamplesFromSegment(experimentName, stateClass, segment, totalSegments))
.sequential();
}
private <T> Flux<PushNotificationExperimentSample<T>> getSamplesFromSegment(final String experimentName,
final Class<T> stateClass,
final int segment,
final int totalSegments) {
return Flux.from(dynamoDbAsyncClient.scanPaginator(ScanRequest.builder()
.tableName(tableName)
.segment(segment)
.totalSegments(totalSegments)
.filterExpression("#experiment = :experiment")
.expressionAttributeNames(Map.of("#experiment", KEY_EXPERIMENT_NAME))
.expressionAttributeValues(Map.of(":experiment", AttributeValue.fromS(experimentName)))
.build())
.items())
.build())
.items())
.handle((item, sink) -> {
try {
final Tuple2<UUID, Byte> aciAndDeviceId = parseSortKey(item.get(ATTR_ACI_AND_DEVICE_ID));
final boolean inExperimentGroup = item.get(ATTR_IN_EXPERIMENT_GROUP).bool();
final T initialState = parseState(item.get(ATTR_INITIAL_STATE).s(), stateClass);
final T finalState = parseState(item.get(ATTR_FINAL_STATE).s(), stateClass);
final T finalState = item.get(ATTR_FINAL_STATE) != null
? parseState(item.get(ATTR_FINAL_STATE).s(), stateClass)
: null;
sink.next(new PushNotificationExperimentSample<>(inExperimentGroup, initialState, finalState));
sink.next(new PushNotificationExperimentSample<>(aciAndDeviceId.getT1(), aciAndDeviceId.getT2(), inExperimentGroup, initialState, finalState));
} catch (final JsonProcessingException e) {
sink.error(e);
}

View File

@@ -4,6 +4,8 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.core.Application;
import io.dropwizard.core.setup.Environment;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import net.sourceforge.argparse4j.inf.Namespace;
import net.sourceforge.argparse4j.inf.Subparser;
import org.slf4j.Logger;
@@ -12,17 +14,14 @@ import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSample;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuples;
import reactor.core.scheduler.Schedulers;
import reactor.util.retry.Retry;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import javax.annotation.Nullable;
import java.time.Duration;
import java.util.UUID;
public class FinishPushNotificationExperimentCommand<T> extends AbstractCommandWithDependencies {
@@ -33,6 +32,18 @@ public class FinishPushNotificationExperimentCommand<T> extends AbstractCommandW
@VisibleForTesting
static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency";
@VisibleForTesting
static final String SEGMENT_COUNT_ARGUMENT = "segments";
private static final String SAMPLES_READ_COUNTER_NAME =
MetricsUtil.name(FinishPushNotificationExperimentCommand.class, "samplesRead");
private static final Counter ACCOUNT_READ_COUNTER =
Metrics.counter(MetricsUtil.name(FinishPushNotificationExperimentCommand.class, "accountRead"));
private static final Counter FINAL_SAMPLE_STORED_COUNTER =
Metrics.counter(MetricsUtil.name(FinishPushNotificationExperimentCommand.class, "finalSampleStored"));
private static final Logger log = LoggerFactory.getLogger(FinishPushNotificationExperimentCommand.class);
public FinishPushNotificationExperimentCommand(final String name,
@@ -57,6 +68,13 @@ public class FinishPushNotificationExperimentCommand<T> extends AbstractCommandW
.dest(MAX_CONCURRENCY_ARGUMENT)
.setDefault(DEFAULT_MAX_CONCURRENCY)
.help("Max concurrency for DynamoDB operations");
subparser.addArgument("--segments")
.type(Integer.class)
.dest(SEGMENT_COUNT_ARGUMENT)
.required(false)
.setDefault(16)
.help("The total number of segments for a DynamoDB scan");
}
@Override
@@ -69,6 +87,7 @@ public class FinishPushNotificationExperimentCommand<T> extends AbstractCommandW
experimentFactory.buildExperiment(commandDependencies, configuration);
final int maxConcurrency = namespace.getInt(MAX_CONCURRENCY_ARGUMENT);
final int segments = namespace.getInt(SEGMENT_COUNT_ARGUMENT);
log.info("Finishing \"{}\" with max concurrency: {}", experiment.getExperimentName(), maxConcurrency);
@@ -76,48 +95,44 @@ public class FinishPushNotificationExperimentCommand<T> extends AbstractCommandW
final PushNotificationExperimentSamples pushNotificationExperimentSamples = commandDependencies.pushNotificationExperimentSamples();
final Flux<PushNotificationExperimentSample<T>> finishedSamples =
pushNotificationExperimentSamples.getDevicesPendingFinalState(experiment.getExperimentName())
.flatMap(accountIdentifierAndDeviceId ->
Mono.fromFuture(() -> accountsManager.getByAccountIdentifierAsync(accountIdentifierAndDeviceId.getT1()))
pushNotificationExperimentSamples.getSamples(experiment.getExperimentName(),
experiment.getStateClass(),
segments,
Schedulers.parallel())
.doOnNext(sample -> Metrics.counter(SAMPLES_READ_COUNTER_NAME, "final", String.valueOf(sample.finalState() != null)).increment())
.flatMap(sample -> {
if (sample.finalState() == null) {
// We still need to record a final state for this sample
return Mono.fromFuture(() -> accountsManager.getByAccountIdentifierAsync(sample.accountIdentifier()))
.retryWhen(Retry.backoff(3, Duration.ofSeconds(1)))
.map(maybeAccount -> Tuples.of(accountIdentifierAndDeviceId.getT1(),
accountIdentifierAndDeviceId.getT2(), maybeAccount)), maxConcurrency)
.map(accountIdentifierAndDeviceIdAndMaybeAccount -> {
final UUID accountIdentifier = accountIdentifierAndDeviceIdAndMaybeAccount.getT1();
final byte deviceId = accountIdentifierAndDeviceIdAndMaybeAccount.getT2();
.doOnNext(ignored -> ACCOUNT_READ_COUNTER.increment())
.flatMap(maybeAccount -> {
final T finalState = experiment.getState(maybeAccount.orElse(null),
maybeAccount.flatMap(account -> account.getDevice(sample.deviceId())).orElse(null));
@Nullable final Account account = accountIdentifierAndDeviceIdAndMaybeAccount.getT3()
.orElse(null);
return Mono.fromFuture(
() -> pushNotificationExperimentSamples.recordFinalState(sample.accountIdentifier(),
sample.deviceId(),
experiment.getExperimentName(),
finalState))
.onErrorResume(ConditionalCheckFailedException.class, throwable -> Mono.empty())
.onErrorResume(JsonProcessingException.class, throwable -> {
log.error("Failed to parse sample state JSON", throwable);
return Mono.empty();
})
.retryWhen(Retry.backoff(3, Duration.ofSeconds(1)))
.onErrorResume(throwable -> {
log.warn("Failed to record final state for {}:{} in experiment {}",
sample.accountIdentifier(), sample.deviceId(), experiment.getExperimentName(), throwable);
@Nullable final Device device = accountIdentifierAndDeviceIdAndMaybeAccount.getT3()
.flatMap(a -> a.getDevice(deviceId))
.orElse(null);
return Tuples.of(accountIdentifier, deviceId, experiment.getState(account, device));
})
.flatMap(accountIdentifierAndDeviceIdAndFinalState -> {
final UUID accountIdentifier = accountIdentifierAndDeviceIdAndFinalState.getT1();
final byte deviceId = accountIdentifierAndDeviceIdAndFinalState.getT2();
final T finalState = accountIdentifierAndDeviceIdAndFinalState.getT3();
return Mono.fromFuture(() -> {
try {
return pushNotificationExperimentSamples.recordFinalState(accountIdentifier, deviceId,
experiment.getExperimentName(), finalState);
} catch (final JsonProcessingException e) {
throw new RuntimeException(e);
}
})
.onErrorResume(ConditionalCheckFailedException.class, throwable -> Mono.empty())
.retryWhen(Retry.backoff(3, Duration.ofSeconds(1)))
.onErrorResume(throwable -> {
log.warn("Failed to record final state for {}:{} in experiment {}",
accountIdentifier, deviceId, experiment.getExperimentName(), throwable);
return Mono.empty();
});
}, maxConcurrency)
.flatMap(Mono::justOrEmpty);
return Mono.empty();
})
.doOnSuccess(ignored -> FINAL_SAMPLE_STORED_COUNTER.increment());
});
} else {
return Mono.just(sample);
}
}, maxConcurrency);
experiment.analyzeResults(finishedSamples);
}