diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/configuration/LocalDynamoDbFactory.java b/service/src/test/java/org/whispersystems/textsecuregcm/configuration/LocalDynamoDbFactory.java index 54e2263cc..f49363b57 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/configuration/LocalDynamoDbFactory.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/configuration/LocalDynamoDbFactory.java @@ -7,6 +7,8 @@ package org.whispersystems.textsecuregcm.configuration; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeName; +import java.util.Optional; +import javax.annotation.Nullable; import org.whispersystems.textsecuregcm.storage.DynamoDbExtension; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; @@ -19,6 +21,8 @@ public class LocalDynamoDbFactory implements DynamoDbClientFactory { private static final DynamoDbExtension EXTENSION = new DynamoDbExtension(DynamoDbExtensionSchema.Tables.values()); + private boolean initExtension = true; + /** * If true, tables will be created the first time a DynamoDB client is built. *

@@ -27,34 +31,56 @@ public class LocalDynamoDbFactory implements DynamoDbClientFactory { @JsonProperty boolean initTables = true; - public LocalDynamoDbFactory() { - try { - EXTENSION.beforeAll(null); - } catch (Exception e) { - throw new RuntimeException(e); - } - - Runtime.getRuntime().addShutdownHook(new Thread(() -> { - try { - EXTENSION.close(); - } catch (Throwable e) { - throw new RuntimeException(e); - } - })); - } + /** + * If specified, will be provided to {@link DynamoDbExtension} to use instead of its embedded container + */ + @Nullable + @JsonProperty + DynamoDbLocalOverrides overrides; @Override public DynamoDbClient buildSyncClient(final AwsCredentialsProvider awsCredentialsProvider, final MetricPublisher metricPublisher) { + initExtensionIfNecessary(); initTablesIfNecessary(); + return EXTENSION.getDynamoDbClient(); } @Override public DynamoDbAsyncClient buildAsyncClient(final AwsCredentialsProvider awsCredentialsProvider, final MetricPublisher metricPublisher) { + initExtensionIfNecessary(); initTablesIfNecessary(); + return EXTENSION.getDynamoDbAsyncClient(); } + private void initExtensionIfNecessary() { + if (initExtension) { + try { + Optional.ofNullable(overrides) + .ifPresent(o -> { + Optional.ofNullable(o.endpoint).ifPresent(EXTENSION::setEndpointOverride); + Optional.ofNullable(o.region).ifPresent(EXTENSION::setRegion); + Optional.ofNullable(o.awsCredentialsProvider).ifPresent(p -> EXTENSION.setAwsCredentialsProvider(p.build())); + }); + + EXTENSION.beforeAll(null); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + try { + EXTENSION.close(); + } catch (Throwable e) { + throw new RuntimeException(e); + } + })); + + initExtension = false; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + private void initTablesIfNecessary() { try { if (initTables) { @@ -65,4 +91,6 @@ public class LocalDynamoDbFactory implements DynamoDbClientFactory { throw new RuntimeException(e); } } + + private record DynamoDbLocalOverrides(@Nullable String endpoint, @Nullable AwsCredentialsProviderFactory awsCredentialsProvider, @Nullable String region) {} } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtension.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtension.java index de9004707..8cc452d7a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtension.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtension.java @@ -9,15 +9,17 @@ import java.net.URI; import java.time.Duration; import java.time.Instant; import java.util.List; -import org.junit.jupiter.api.extension.AfterAllCallback; import org.junit.jupiter.api.extension.AfterEachCallback; import org.junit.jupiter.api.extension.BeforeAllCallback; import org.junit.jupiter.api.extension.BeforeEachCallback; import org.junit.jupiter.api.extension.ExtensionContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.testcontainers.containers.GenericContainer; import org.testcontainers.utility.DockerImageName; import org.whispersystems.textsecuregcm.util.TestcontainersImages; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; @@ -31,7 +33,7 @@ import software.amazon.awssdk.services.dynamodb.model.LocalSecondaryIndex; import software.amazon.awssdk.services.dynamodb.model.ProvisionedThroughput; import software.amazon.awssdk.services.dynamodb.model.ResourceNotFoundException; -public class DynamoDbExtension implements BeforeAllCallback, BeforeEachCallback, AfterEachCallback, AfterAllCallback, ExtensionContext.Store.CloseableResource { +public class DynamoDbExtension implements BeforeAllCallback, BeforeEachCallback, AfterEachCallback, ExtensionContext.Store.CloseableResource { public interface TableSchema { String tableName(); @@ -51,6 +53,8 @@ public class DynamoDbExtension implements BeforeAllCallback, BeforeEachCallback, List localSecondaryIndexes ) implements TableSchema { } + private static final Logger logger = LoggerFactory.getLogger(DynamoDbExtension.class); + static final ProvisionedThroughput DEFAULT_PROVISIONED_THROUGHPUT = ProvisionedThroughput.builder() .readCapacityUnits(20L) .writeCapacityUnits(20L) @@ -62,14 +66,31 @@ public class DynamoDbExtension implements BeforeAllCallback, BeforeEachCallback, .withExposedPorts(CONTAINER_PORT) .withCommand("-jar DynamoDBLocal.jar -inMemory -sharedDb -disableTelemetry"); + // These are static to simplify configuration in WhisperServerServiceTest + private static String endpointOverride; + private static Region region = Region.of("local"); + private static AwsCredentialsProvider awsCredentialsProvider = StaticCredentialsProvider.create(AwsBasicCredentials.create("test", "test")); + private static DynamoDbClient dynamoDb; + private static DynamoDbAsyncClient dynamoDbAsync; + private final List schemas; - private DynamoDbClient dynamoDb; - private DynamoDbAsyncClient dynamoDbAsync; public DynamoDbExtension(TableSchema... schemas) { this.schemas = List.of(schemas); } + public void setEndpointOverride(String endpointOverride) { + DynamoDbExtension.endpointOverride = endpointOverride; + } + + public void setRegion(String region) { + DynamoDbExtension.region = Region.of(region); + } + + public void setAwsCredentialsProvider(AwsCredentialsProvider awsCredentialsProvider) { + DynamoDbExtension.awsCredentialsProvider = awsCredentialsProvider; + } + /** * Starts the DynamoDB server */ @@ -111,19 +132,15 @@ public class DynamoDbExtension implements BeforeAllCallback, BeforeEachCallback, }); } - @Override - public void afterAll(ExtensionContext context) throws Exception { - dynamoDb.close(); - dynamoDbAsync.close(); - } - @Override public void close() throws Throwable { stopServer(); } private void startServer() { - dynamoDbContainer.start(); + if (endpointOverride == null) { + dynamoDbContainer.start(); + } initializeClient(); } @@ -153,6 +170,7 @@ public class DynamoDbExtension implements BeforeAllCallback, BeforeEachCallback, } private void createTables() { + logger.debug("Creating tables"); schemas.forEach(this::createTable); } @@ -182,17 +200,18 @@ public class DynamoDbExtension implements BeforeAllCallback, BeforeEachCallback, } private void initializeClient() { - final URI endpoint = URI.create( - String.format("http://%s:%d", dynamoDbContainer.getHost(), dynamoDbContainer.getMappedPort(CONTAINER_PORT))); + final URI endpoint = endpointOverride == null ? + URI.create(String.format("http://%s:%d", dynamoDbContainer.getHost(), dynamoDbContainer.getMappedPort(CONTAINER_PORT))) + : URI.create(endpointOverride); dynamoDb = DynamoDbClient.builder() - .region(Region.of("local")) - .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("test", "test"))) + .region(region) + .credentialsProvider(awsCredentialsProvider) .endpointOverride(endpoint) .build(); dynamoDbAsync = DynamoDbAsyncClient.builder() - .region(Region.of("local")) - .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("test", "test"))) + .region(region) + .credentialsProvider(awsCredentialsProvider) .endpointOverride(endpoint) .build(); }