Add copy endpoint to ArchiveController

Co-authored-by: Jonathan Klabunde Tomer <125505367+jkt-signal@users.noreply.github.com>
Co-authored-by: Chris Eager <79161849+eager-signal@users.noreply.github.com>
This commit is contained in:
ravi-signal
2023-11-28 11:45:41 -06:00
committed by GitHub
parent 1da3f96d10
commit 202dd8e92d
24 changed files with 1918 additions and 248 deletions

View File

@@ -8,27 +8,31 @@ package org.whispersystems.textsecuregcm.backup;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.commons.lang3.RandomUtils;
import org.assertj.core.api.ThrowableAssert;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
@@ -36,28 +40,28 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.VerificationFailedException;
import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation;
import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser;
import org.whispersystems.textsecuregcm.backup.BackupManager.BackupInfo;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtension;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.TestClock;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
public class BackupManagerTest {
@RegisterExtension
private static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
DynamoDbExtensionSchema.Tables.BACKUPS);
public static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
DynamoDbExtensionSchema.Tables.BACKUPS,
DynamoDbExtensionSchema.Tables.BACKUP_MEDIA);
private final TestClock testClock = TestClock.now();
private final BackupAuthTestUtil backupAuthTestUtil = new BackupAuthTestUtil(testClock);
private final TusBackupCredentialGenerator tusCredentialGenerator = mock(TusBackupCredentialGenerator.class);
private final Cdn3BackupCredentialGenerator tusCredentialGenerator = mock(Cdn3BackupCredentialGenerator.class);
private final RemoteStorageManager remoteStorageManager = mock(RemoteStorageManager.class);
private final byte[] backupKey = RandomUtils.nextBytes(32);
private final UUID aci = UUID.randomUUID();
@@ -68,16 +72,19 @@ public class BackupManagerTest {
reset(tusCredentialGenerator);
testClock.unpin();
this.backupManager = new BackupManager(
new BackupsDb(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.BACKUPS.tableName(), DynamoDbExtensionSchema.Tables.BACKUP_MEDIA.tableName(),
testClock),
backupAuthTestUtil.params,
tusCredentialGenerator,
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.BACKUPS.tableName(),
remoteStorageManager,
Map.of(3, "cdn3.example.org/attachments"),
testClock);
}
@ParameterizedTest
@EnumSource(mode = EnumSource.Mode.EXCLUDE, names = {"NONE"})
public void createBackup(final BackupTier backupTier) throws InvalidInputException, VerificationFailedException {
public void createBackup(final BackupTier backupTier) {
final Instant now = Instant.ofEpochSecond(Duration.ofDays(1).getSeconds());
testClock.pin(now);
@@ -89,18 +96,18 @@ public class BackupManagerTest {
verify(tusCredentialGenerator, times(1))
.generateUpload(encodedBackupId, BackupManager.MESSAGE_BACKUP_NAME);
final BackupInfo info = backupManager.backupInfo(backupUser).join();
final BackupManager.BackupInfo info = backupManager.backupInfo(backupUser).join();
assertThat(info.backupSubdir()).isEqualTo(encodedBackupId);
assertThat(info.messageBackupKey()).isEqualTo(BackupManager.MESSAGE_BACKUP_NAME);
assertThat(info.mediaUsedSpace()).isEqualTo(Optional.empty());
// Check that the initial expiration times are the initial write times
checkExpectedExpirations(now, backupTier == BackupTier.MEDIA ? now : null, backupUser.backupId());
checkExpectedExpirations(now, backupTier == BackupTier.MEDIA ? now : null, backupUser);
}
@ParameterizedTest
@EnumSource(mode = EnumSource.Mode.EXCLUDE, names = {"NONE"})
public void ttlRefresh(final BackupTier backupTier) throws InvalidInputException, VerificationFailedException {
public void ttlRefresh(final BackupTier backupTier) {
final AuthenticatedBackupUser backupUser = backupUser(RandomUtils.nextBytes(16), backupTier);
final Instant tstart = Instant.ofEpochSecond(1).plus(Duration.ofDays(1));
@@ -117,12 +124,12 @@ public class BackupManagerTest {
checkExpectedExpirations(
tnext,
backupTier == BackupTier.MEDIA ? tnext : null,
backupUser.backupId());
backupUser);
}
@ParameterizedTest
@EnumSource(mode = EnumSource.Mode.EXCLUDE, names = {"NONE"})
public void createBackupRefreshesTtl(final BackupTier backupTier) throws VerificationFailedException {
public void createBackupRefreshesTtl(final BackupTier backupTier) {
final Instant tstart = Instant.ofEpochSecond(1).plus(Duration.ofDays(1));
final Instant tnext = tstart.plus(Duration.ofSeconds(1));
@@ -139,7 +146,7 @@ public class BackupManagerTest {
checkExpectedExpirations(
tnext,
backupTier == BackupTier.MEDIA ? tnext : null,
backupUser.backupId());
backupUser);
}
@Test
@@ -151,9 +158,10 @@ public class BackupManagerTest {
final byte[] signature = keyPair.getPrivateKey().calculateSignature(presentation.serialize());
// haven't set a public key yet
assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(unwrapExceptions(() -> backupManager.authenticateBackupUser(presentation, signature)))
.extracting(ex -> ex.getStatus().getCode())
assertThat(CompletableFutureTestUtil.assertFailsWithCause(
StatusRuntimeException.class,
backupManager.authenticateBackupUser(presentation, signature))
.getStatus().getCode())
.isEqualTo(Status.NOT_FOUND.getCode());
}
@@ -170,9 +178,10 @@ public class BackupManagerTest {
backupManager.setPublicKey(presentation, signature1, keyPair1.getPublicKey()).join();
// shouldn't be able to set a different public key
assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(unwrapExceptions(() -> backupManager.setPublicKey(presentation, signature2, keyPair2.getPublicKey())))
.extracting(ex -> ex.getStatus().getCode())
assertThat(CompletableFutureTestUtil.assertFailsWithCause(
StatusRuntimeException.class,
backupManager.setPublicKey(presentation, signature2, keyPair2.getPublicKey()))
.getStatus().getCode())
.isEqualTo(Status.UNAUTHENTICATED.getCode());
// should be able to set the same public key again (noop)
@@ -193,16 +202,17 @@ public class BackupManagerTest {
// shouldn't be able to set a public key with an invalid signature
assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(unwrapExceptions(() -> backupManager.setPublicKey(presentation, wrongSignature, keyPair.getPublicKey())))
.isThrownBy(() -> backupManager.setPublicKey(presentation, wrongSignature, keyPair.getPublicKey()))
.extracting(ex -> ex.getStatus().getCode())
.isEqualTo(Status.UNAUTHENTICATED.getCode());
backupManager.setPublicKey(presentation, signature, keyPair.getPublicKey()).join();
// shouldn't be able to authenticate with an invalid signature
assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(unwrapExceptions(() -> backupManager.authenticateBackupUser(presentation, wrongSignature)))
.extracting(ex -> ex.getStatus().getCode())
assertThat(CompletableFutureTestUtil.assertFailsWithCause(
StatusRuntimeException.class,
backupManager.authenticateBackupUser(presentation, wrongSignature))
.getStatus().getCode())
.isEqualTo(Status.UNAUTHENTICATED.getCode());
// correct signature
@@ -212,11 +222,12 @@ public class BackupManagerTest {
}
@Test
public void credentialExpiration() throws InvalidInputException, VerificationFailedException {
public void credentialExpiration() throws VerificationFailedException {
// credential for 1 day after epoch
testClock.pin(Instant.ofEpochSecond(1).plus(Duration.ofDays(1)));
final BackupAuthCredentialPresentation oldCredential = backupAuthTestUtil.getPresentation(BackupTier.MESSAGES, backupKey, aci);
final BackupAuthCredentialPresentation oldCredential = backupAuthTestUtil.getPresentation(BackupTier.MESSAGES,
backupKey, aci);
final ECKeyPair keyPair = Curve.generateKeyPair();
final byte[] signature = keyPair.getPrivateKey().calculateSignature(oldCredential.serialize());
backupManager.setPublicKey(oldCredential, signature, keyPair.getPublicKey()).join();
@@ -231,28 +242,95 @@ public class BackupManagerTest {
// should be rejected the day after that
testClock.pin(Instant.ofEpochSecond(1).plus(Duration.ofDays(3)));
assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(unwrapExceptions(() -> backupManager.authenticateBackupUser(oldCredential, signature)))
.extracting(ex -> ex.getStatus().getCode())
assertThat(CompletableFutureTestUtil.assertFailsWithCause(
StatusRuntimeException.class,
backupManager.authenticateBackupUser(oldCredential, signature))
.getStatus().getCode())
.isEqualTo(Status.UNAUTHENTICATED.getCode());
}
@Test
public void copySuccess() {
final AuthenticatedBackupUser backupUser = backupUser(RandomUtils.nextBytes(16), BackupTier.MEDIA);
when(tusCredentialGenerator.generateUpload(any(), any()))
.thenReturn(new MessageBackupUploadDescriptor(3, "def", Collections.emptyMap(), ""));
when(remoteStorageManager.copy(eq(URI.create("cdn3.example.org/attachments/abc")), eq(100), any(), any()))
.thenReturn(CompletableFuture.completedFuture(null));
final BackupManager.StorageDescriptor copied = backupManager.copyToBackup(
backupUser, 3, "abc", 100, mock(MediaEncryptionParameters.class),
"def".getBytes(StandardCharsets.UTF_8)).join();
assertThat(copied.cdn()).isEqualTo(3);
assertThat(copied.key()).isEqualTo("def".getBytes(StandardCharsets.UTF_8));
final Map<String, AttributeValue> backup = getBackupItem(backupUser);
final long bytesUsed = AttributeValues.getLong(backup, BackupsDb.ATTR_MEDIA_BYTES_USED, 0L);
assertThat(bytesUsed).isEqualTo(100);
final long mediaCount = AttributeValues.getLong(backup, BackupsDb.ATTR_MEDIA_COUNT, 0L);
assertThat(mediaCount).isEqualTo(1);
final Map<String, AttributeValue> mediaItem = getBackupMediaItem(backupUser,
"def".getBytes(StandardCharsets.UTF_8));
final long mediaLength = AttributeValues.getLong(mediaItem, BackupsDb.ATTR_LENGTH, 0L);
assertThat(mediaLength).isEqualTo(100L);
}
@Test
public void copyFailure() {
final AuthenticatedBackupUser backupUser = backupUser(RandomUtils.nextBytes(16), BackupTier.MEDIA);
when(tusCredentialGenerator.generateUpload(any(), any()))
.thenReturn(new MessageBackupUploadDescriptor(3, "def", Collections.emptyMap(), ""));
when(remoteStorageManager.copy(eq(URI.create("cdn3.example.org/attachments/abc")), eq(100), any(), any()))
.thenReturn(CompletableFuture.failedFuture(new SourceObjectNotFoundException()));
CompletableFutureTestUtil.assertFailsWithCause(SourceObjectNotFoundException.class,
backupManager.copyToBackup(
backupUser,
3, "abc", 100,
mock(MediaEncryptionParameters.class),
"def".getBytes(StandardCharsets.UTF_8)));
final Map<String, AttributeValue> backup = getBackupItem(backupUser);
assertThat(AttributeValues.getLong(backup, BackupsDb.ATTR_MEDIA_BYTES_USED, -1L)).isEqualTo(0L);
assertThat(AttributeValues.getLong(backup, BackupsDb.ATTR_MEDIA_COUNT, -1L)).isEqualTo(0L);
final Map<String, AttributeValue> media = getBackupMediaItem(backupUser, "def".getBytes(StandardCharsets.UTF_8));
assertThat(media).isEmpty();
}
private Map<String, AttributeValue> getBackupItem(final AuthenticatedBackupUser backupUser) {
return DYNAMO_DB_EXTENSION.getDynamoDbClient().getItem(GetItemRequest.builder()
.tableName(DynamoDbExtensionSchema.Tables.BACKUPS.tableName())
.key(Map.of(BackupsDb.KEY_BACKUP_ID_HASH, AttributeValues.b(hashedBackupId(backupUser.backupId()))))
.build())
.item();
}
private Map<String, AttributeValue> getBackupMediaItem(final AuthenticatedBackupUser backupUser,
final byte[] mediaId) {
return DYNAMO_DB_EXTENSION.getDynamoDbClient().getItem(GetItemRequest.builder()
.tableName(DynamoDbExtensionSchema.Tables.BACKUP_MEDIA.tableName())
.key(Map.of(
BackupsDb.KEY_BACKUP_ID_HASH, AttributeValues.b(hashedBackupId(backupUser.backupId())),
BackupsDb.KEY_MEDIA_ID, AttributeValues.b(mediaId)))
.build())
.item();
}
private void checkExpectedExpirations(
final Instant expectedExpiration,
final @Nullable Instant expectedMediaExpiration,
final byte[] backupId) {
final GetItemResponse item = DYNAMO_DB_EXTENSION.getDynamoDbClient().getItem(GetItemRequest.builder()
.tableName(DynamoDbExtensionSchema.Tables.BACKUPS.tableName())
.key(Map.of(BackupManager.KEY_BACKUP_ID_HASH, AttributeValues.b(hashedBackupId(backupId))))
.build());
assertThat(item.hasItem()).isTrue();
final Instant refresh = Instant.ofEpochSecond(Long.parseLong(item.item().get(BackupManager.ATTR_LAST_REFRESH).n()));
final AuthenticatedBackupUser backupUser) {
final Map<String, AttributeValue> item = getBackupItem(backupUser);
final Instant refresh = Instant.ofEpochSecond(Long.parseLong(item.get(BackupsDb.ATTR_LAST_REFRESH).n()));
assertThat(refresh).isEqualTo(expectedExpiration);
if (expectedMediaExpiration == null) {
assertThat(item.item()).doesNotContainKey(BackupManager.ATTR_LAST_MEDIA_REFRESH);
assertThat(item).doesNotContainKey(BackupsDb.ATTR_LAST_MEDIA_REFRESH);
} else {
assertThat(Instant.ofEpochSecond(Long.parseLong(item.item().get(BackupManager.ATTR_LAST_MEDIA_REFRESH).n())))
assertThat(Instant.ofEpochSecond(Long.parseLong(item.get(BackupsDb.ATTR_LAST_MEDIA_REFRESH).n())))
.isEqualTo(expectedMediaExpiration);
}
}
@@ -268,17 +346,4 @@ public class BackupManagerTest {
private AuthenticatedBackupUser backupUser(final byte[] backupId, final BackupTier backupTier) {
return new AuthenticatedBackupUser(backupId, backupTier);
}
private <T> ThrowableAssert.ThrowingCallable unwrapExceptions(final Supplier<CompletableFuture<T>> f) {
return () -> {
try {
f.get().join();
} catch (Exception e) {
if (ExceptionUtils.unwrap(e) instanceof StatusRuntimeException ex) {
throw ex;
}
throw e;
}
};
}
}

View File

@@ -0,0 +1,93 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.backup;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import org.apache.commons.lang3.RandomUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtension;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.TestClock;
public class BackupsDbTest {
@RegisterExtension
public static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
DynamoDbExtensionSchema.Tables.BACKUPS,
DynamoDbExtensionSchema.Tables.BACKUP_MEDIA);
private final TestClock testClock = TestClock.now();
private BackupsDb backupsDb;
@BeforeEach
public void setup() {
testClock.unpin();
backupsDb = new BackupsDb(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.BACKUPS.tableName(), DynamoDbExtensionSchema.Tables.BACKUP_MEDIA.tableName(),
testClock);
}
@Test
public void trackMediaIdempotent() {
final AuthenticatedBackupUser backupUser = backupUser(RandomUtils.nextBytes(16), BackupTier.MEDIA);
this.backupsDb.trackMedia(backupUser, "abc".getBytes(StandardCharsets.UTF_8), 100).join();
assertDoesNotThrow(() ->
this.backupsDb.trackMedia(backupUser, "abc".getBytes(StandardCharsets.UTF_8), 100).join());
}
@Test
public void trackMediaLengthChange() {
final AuthenticatedBackupUser backupUser = backupUser(RandomUtils.nextBytes(16), BackupTier.MEDIA);
this.backupsDb.trackMedia(backupUser, "abc".getBytes(StandardCharsets.UTF_8), 100).join();
CompletableFutureTestUtil.assertFailsWithCause(InvalidLengthException.class,
this.backupsDb.trackMedia(backupUser, "abc".getBytes(StandardCharsets.UTF_8), 99));
}
@Test
public void trackMediaStats() {
final AuthenticatedBackupUser backupUser = backupUser(RandomUtils.nextBytes(16), BackupTier.MEDIA);
// add at least one message backup so we can describe it
backupsDb.addMessageBackup(backupUser).join();
int total = 0;
for (int i = 0; i < 5; i++) {
this.backupsDb.trackMedia(backupUser, Integer.toString(i).getBytes(StandardCharsets.UTF_8), i).join();
total += i;
final BackupsDb.BackupDescription description = this.backupsDb.describeBackup(backupUser).join();
assertThat(description.mediaUsedSpace().get()).isEqualTo(total);
}
for (int i = 0; i < 5; i++) {
this.backupsDb.untrackMedia(backupUser, Integer.toString(i).getBytes(StandardCharsets.UTF_8), i).join();
total -= i;
final BackupsDb.BackupDescription description = this.backupsDb.describeBackup(backupUser).join();
assertThat(description.mediaUsedSpace().get()).isEqualTo(total);
}
}
private static byte[] hashedBackupId(final byte[] backupId) {
try {
return Arrays.copyOf(MessageDigest.getInstance("SHA-256").digest(backupId), 16);
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
}
private AuthenticatedBackupUser backupUser(final byte[] backupId, final BackupTier backupTier) {
return new AuthenticatedBackupUser(backupId, backupTier);
}
}

View File

@@ -16,10 +16,10 @@ import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
public class TusBackupCredentialGeneratorTest {
public class Cdn3BackupCredentialGeneratorTest {
@Test
public void uploadGenerator() {
TusBackupCredentialGenerator generator = new TusBackupCredentialGenerator(new TusConfiguration(
Cdn3BackupCredentialGenerator generator = new Cdn3BackupCredentialGenerator(new TusConfiguration(
new SecretBytes(RandomUtils.nextBytes(32)),
"https://example.org/upload"));
@@ -33,7 +33,7 @@ public class TusBackupCredentialGeneratorTest {
@Test
public void readCredential() {
TusBackupCredentialGenerator generator = new TusBackupCredentialGenerator(new TusConfiguration(
Cdn3BackupCredentialGenerator generator = new Cdn3BackupCredentialGenerator(new TusConfiguration(
new SecretBytes(RandomUtils.nextBytes(32)),
"https://example.org/upload"));

View File

@@ -0,0 +1,185 @@
package org.whispersystems.textsecuregcm.backup;
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.get;
import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.util.Arrays;
import java.util.Collections;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.Mac;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
@ExtendWith(DropwizardExtensionsSupport.class)
public class Cdn3RemoteStorageManagerTest {
private static byte[] HMAC_KEY = getRandomBytes(32);
private static byte[] AES_KEY = getRandomBytes(32);
private static byte[] IV = getRandomBytes(16);
@RegisterExtension
private final WireMockExtension wireMock = WireMockExtension.newInstance()
.options(wireMockConfig().dynamicPort())
.build();
private static String SMALL_CDN2 = "a small object from cdn2";
private static String SMALL_CDN3 = "a small object from cdn3";
private static String LARGE = "a".repeat(1024 * 1024 * 5);
private RemoteStorageManager remoteStorageManager;
@BeforeEach
public void init() throws CertificateException {
remoteStorageManager = new Cdn3RemoteStorageManager(
Executors.newSingleThreadScheduledExecutor(),
new CircuitBreakerConfiguration(),
new RetryConfiguration(),
Collections.emptyList());
wireMock.stubFor(get(urlEqualTo("/cdn2/source/small"))
.willReturn(aResponse()
.withHeader("Content-Length", Integer.toString(SMALL_CDN2.length()))
.withBody(SMALL_CDN2)));
wireMock.stubFor(get(urlEqualTo("/cdn3/source/small"))
.willReturn(aResponse()
.withHeader("Content-Length", Integer.toString(SMALL_CDN3.length()))
.withBody(SMALL_CDN3)));
wireMock.stubFor(get(urlEqualTo("/cdn3/source/large"))
.willReturn(aResponse()
.withHeader("Content-Length", Integer.toString(LARGE.length()))
.withBody(LARGE)));
wireMock.stubFor(get(urlEqualTo("/cdn3/source/missing"))
.willReturn(aResponse().withStatus(404)));
}
@ParameterizedTest
@ValueSource(ints = {2, 3})
public void copySmall(final int sourceCdn)
throws InvalidAlgorithmParameterException, InvalidKeyException, IllegalBlockSizeException, BadPaddingException {
final String expectedSource = switch (sourceCdn) {
case 2 -> SMALL_CDN2;
case 3 -> SMALL_CDN3;
default -> throw new AssertionError();
};
wireMock.stubFor(post(urlEqualTo("/cdn3/dest"))
.willReturn(aResponse()
.withStatus(201)));
remoteStorageManager.copy(
URI.create(wireMock.url("/cdn" + sourceCdn + "/source/small")),
expectedSource.length(),
new MediaEncryptionParameters(AES_KEY, HMAC_KEY, IV),
new MessageBackupUploadDescriptor(3, "test", Collections.emptyMap(), wireMock.url("/cdn3/dest")))
.toCompletableFuture().join();
final byte[] destBody = wireMock.findAll(postRequestedFor(urlEqualTo("/cdn3/dest"))).get(0).getBody();
assertThat(new String(decrypt(destBody), StandardCharsets.UTF_8))
.isEqualTo(expectedSource);
}
@Test
public void copyLarge()
throws InvalidAlgorithmParameterException, IllegalBlockSizeException, BadPaddingException, InvalidKeyException {
wireMock.stubFor(post(urlEqualTo("/cdn3/dest"))
.willReturn(aResponse()
.withStatus(201)));
final MediaEncryptionParameters params = new MediaEncryptionParameters(AES_KEY, HMAC_KEY, IV);
remoteStorageManager.copy(
URI.create(wireMock.url("/cdn3/source/large")),
LARGE.length(),
params,
new MessageBackupUploadDescriptor(3, "test", Collections.emptyMap(), wireMock.url("/cdn3/dest")))
.toCompletableFuture().join();
final byte[] destBody = wireMock.findAll(postRequestedFor(urlEqualTo("/cdn3/dest"))).get(0).getBody();
assertThat(destBody.length).isEqualTo(new BackupMediaEncrypter(params).outputSize(LARGE.length()));
assertThat(new String(decrypt(destBody), StandardCharsets.UTF_8)).isEqualTo(LARGE);
}
@Test
public void incorrectLength() {
CompletableFutureTestUtil.assertFailsWithCause(InvalidLengthException.class,
remoteStorageManager.copy(
URI.create(wireMock.url("/cdn3/source/small")),
SMALL_CDN3.length() - 1,
new MediaEncryptionParameters(AES_KEY, HMAC_KEY, IV),
new MessageBackupUploadDescriptor(3, "test", Collections.emptyMap(), wireMock.url("/cdn3/dest")))
.toCompletableFuture());
}
@Test
public void sourceMissing() {
CompletableFutureTestUtil.assertFailsWithCause(SourceObjectNotFoundException.class,
remoteStorageManager.copy(
URI.create(wireMock.url("/cdn3/source/missing")),
1,
new MediaEncryptionParameters(AES_KEY, HMAC_KEY, IV),
new MessageBackupUploadDescriptor(3, "test", Collections.emptyMap(), wireMock.url("/cdn3/dest")))
.toCompletableFuture());
}
private byte[] decrypt(final byte[] encrypted)
throws InvalidAlgorithmParameterException, InvalidKeyException, IllegalBlockSizeException, BadPaddingException {
final Mac mac;
try {
mac = Mac.getInstance("HmacSHA256");
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
mac.init(new SecretKeySpec(HMAC_KEY, "HmacSHA256"));
mac.update(encrypted, 0, encrypted.length - mac.getMacLength());
assertArrayEquals(mac.doFinal(),
Arrays.copyOfRange(encrypted, encrypted.length - mac.getMacLength(), encrypted.length));
assertArrayEquals(IV, Arrays.copyOf(encrypted, 16));
final Cipher cipher;
try {
cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
} catch (NoSuchAlgorithmException | NoSuchPaddingException e) {
throw new AssertionError(e);
}
cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(AES_KEY, "AES"), new IvParameterSpec(IV));
return cipher.doFinal(encrypted, IV.length, encrypted.length - IV.length - mac.getMacLength());
}
private static byte[] getRandomBytes(int length) {
byte[] result = new byte[length];
ThreadLocalRandom.current().nextBytes(result);
return result;
}
}