Add overrides configuration to LocalDynamoDbFactory

This commit is contained in:
Chris Eager
2025-07-23 09:51:37 -05:00
committed by Chris Eager
parent 83d19ac8ed
commit c99b1cada1
2 changed files with 79 additions and 32 deletions

View File

@@ -7,6 +7,8 @@ package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import java.util.Optional;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtension;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
@@ -19,6 +21,8 @@ public class LocalDynamoDbFactory implements DynamoDbClientFactory {
private static final DynamoDbExtension EXTENSION = new DynamoDbExtension(DynamoDbExtensionSchema.Tables.values());
private boolean initExtension = true;
/**
* If true, tables will be created the first time a DynamoDB client is built.
* <p>
@@ -27,34 +31,56 @@ public class LocalDynamoDbFactory implements DynamoDbClientFactory {
@JsonProperty
boolean initTables = true;
public LocalDynamoDbFactory() {
try {
EXTENSION.beforeAll(null);
} catch (Exception e) {
throw new RuntimeException(e);
}
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
EXTENSION.close();
} catch (Throwable e) {
throw new RuntimeException(e);
}
}));
}
/**
* If specified, will be provided to {@link DynamoDbExtension} to use instead of its embedded container
*/
@Nullable
@JsonProperty
DynamoDbLocalOverrides overrides;
@Override
public DynamoDbClient buildSyncClient(final AwsCredentialsProvider awsCredentialsProvider, final MetricPublisher metricPublisher) {
initExtensionIfNecessary();
initTablesIfNecessary();
return EXTENSION.getDynamoDbClient();
}
@Override
public DynamoDbAsyncClient buildAsyncClient(final AwsCredentialsProvider awsCredentialsProvider, final MetricPublisher metricPublisher) {
initExtensionIfNecessary();
initTablesIfNecessary();
return EXTENSION.getDynamoDbAsyncClient();
}
private void initExtensionIfNecessary() {
if (initExtension) {
try {
Optional.ofNullable(overrides)
.ifPresent(o -> {
Optional.ofNullable(o.endpoint).ifPresent(EXTENSION::setEndpointOverride);
Optional.ofNullable(o.region).ifPresent(EXTENSION::setRegion);
Optional.ofNullable(o.awsCredentialsProvider).ifPresent(p -> EXTENSION.setAwsCredentialsProvider(p.build()));
});
EXTENSION.beforeAll(null);
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
EXTENSION.close();
} catch (Throwable e) {
throw new RuntimeException(e);
}
}));
initExtension = false;
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
private void initTablesIfNecessary() {
try {
if (initTables) {
@@ -65,4 +91,6 @@ public class LocalDynamoDbFactory implements DynamoDbClientFactory {
throw new RuntimeException(e);
}
}
private record DynamoDbLocalOverrides(@Nullable String endpoint, @Nullable AwsCredentialsProviderFactory awsCredentialsProvider, @Nullable String region) {}
}

View File

@@ -9,15 +9,17 @@ import java.net.URI;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.utility.DockerImageName;
import org.whispersystems.textsecuregcm.util.TestcontainersImages;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
@@ -31,7 +33,7 @@ import software.amazon.awssdk.services.dynamodb.model.LocalSecondaryIndex;
import software.amazon.awssdk.services.dynamodb.model.ProvisionedThroughput;
import software.amazon.awssdk.services.dynamodb.model.ResourceNotFoundException;
public class DynamoDbExtension implements BeforeAllCallback, BeforeEachCallback, AfterEachCallback, AfterAllCallback, ExtensionContext.Store.CloseableResource {
public class DynamoDbExtension implements BeforeAllCallback, BeforeEachCallback, AfterEachCallback, ExtensionContext.Store.CloseableResource {
public interface TableSchema {
String tableName();
@@ -51,6 +53,8 @@ public class DynamoDbExtension implements BeforeAllCallback, BeforeEachCallback,
List<LocalSecondaryIndex> localSecondaryIndexes
) implements TableSchema { }
private static final Logger logger = LoggerFactory.getLogger(DynamoDbExtension.class);
static final ProvisionedThroughput DEFAULT_PROVISIONED_THROUGHPUT = ProvisionedThroughput.builder()
.readCapacityUnits(20L)
.writeCapacityUnits(20L)
@@ -62,14 +66,31 @@ public class DynamoDbExtension implements BeforeAllCallback, BeforeEachCallback,
.withExposedPorts(CONTAINER_PORT)
.withCommand("-jar DynamoDBLocal.jar -inMemory -sharedDb -disableTelemetry");
// These are static to simplify configuration in WhisperServerServiceTest
private static String endpointOverride;
private static Region region = Region.of("local");
private static AwsCredentialsProvider awsCredentialsProvider = StaticCredentialsProvider.create(AwsBasicCredentials.create("test", "test"));
private static DynamoDbClient dynamoDb;
private static DynamoDbAsyncClient dynamoDbAsync;
private final List<TableSchema> schemas;
private DynamoDbClient dynamoDb;
private DynamoDbAsyncClient dynamoDbAsync;
public DynamoDbExtension(TableSchema... schemas) {
this.schemas = List.of(schemas);
}
public void setEndpointOverride(String endpointOverride) {
DynamoDbExtension.endpointOverride = endpointOverride;
}
public void setRegion(String region) {
DynamoDbExtension.region = Region.of(region);
}
public void setAwsCredentialsProvider(AwsCredentialsProvider awsCredentialsProvider) {
DynamoDbExtension.awsCredentialsProvider = awsCredentialsProvider;
}
/**
* Starts the DynamoDB server
*/
@@ -111,19 +132,15 @@ public class DynamoDbExtension implements BeforeAllCallback, BeforeEachCallback,
});
}
@Override
public void afterAll(ExtensionContext context) throws Exception {
dynamoDb.close();
dynamoDbAsync.close();
}
@Override
public void close() throws Throwable {
stopServer();
}
private void startServer() {
dynamoDbContainer.start();
if (endpointOverride == null) {
dynamoDbContainer.start();
}
initializeClient();
}
@@ -153,6 +170,7 @@ public class DynamoDbExtension implements BeforeAllCallback, BeforeEachCallback,
}
private void createTables() {
logger.debug("Creating tables");
schemas.forEach(this::createTable);
}
@@ -182,17 +200,18 @@ public class DynamoDbExtension implements BeforeAllCallback, BeforeEachCallback,
}
private void initializeClient() {
final URI endpoint = URI.create(
String.format("http://%s:%d", dynamoDbContainer.getHost(), dynamoDbContainer.getMappedPort(CONTAINER_PORT)));
final URI endpoint = endpointOverride == null ?
URI.create(String.format("http://%s:%d", dynamoDbContainer.getHost(), dynamoDbContainer.getMappedPort(CONTAINER_PORT)))
: URI.create(endpointOverride);
dynamoDb = DynamoDbClient.builder()
.region(Region.of("local"))
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("test", "test")))
.region(region)
.credentialsProvider(awsCredentialsProvider)
.endpointOverride(endpoint)
.build();
dynamoDbAsync = DynamoDbAsyncClient.builder()
.region(Region.of("local"))
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("test", "test")))
.region(region)
.credentialsProvider(awsCredentialsProvider)
.endpointOverride(endpoint)
.build();
}