diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/EncryptDeviceCreationTimestampCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/EncryptDeviceCreationTimestampCommand.java index d9594d901..65d5faac7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/EncryptDeviceCreationTimestampCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/EncryptDeviceCreationTimestampCommand.java @@ -12,10 +12,12 @@ import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import java.time.Duration; import net.sourceforge.argparse4j.inf.Subparser; +import org.signal.libsignal.protocol.IdentityKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.EncryptDeviceCreationTimestampUtil; import org.whispersystems.textsecuregcm.util.Util; import reactor.core.publisher.Mono; @@ -28,8 +30,6 @@ public class EncryptDeviceCreationTimestampCommand extends AbstractSinglePassCra private static final int MAX_CONCURRENCY = 16; - private static final String ENCRYPTED_CREATION_TIMESTAMP_COUNTER_NAME = - name(EncryptDeviceCreationTimestampCommand.class, "encryptedCreationTimestamp"); private static final String PROCESSED_ACCOUNT_COUNTER_NAME = name(EncryptDeviceCreationTimestampCommand.class, "processedAccount"); @@ -54,35 +54,33 @@ public class EncryptDeviceCreationTimestampCommand extends AbstractSinglePassCra @Override protected void crawlAccounts(final Flux accounts) { final boolean isDryRun = getNamespace().getBoolean(DRY_RUN_ARGUMENT); - final Counter encryptedTimestampCounter = - Metrics.counter(ENCRYPTED_CREATION_TIMESTAMP_COUNTER_NAME, "dryRun", String.valueOf(isDryRun)); final Counter processedAccountCounter = Metrics.counter(PROCESSED_ACCOUNT_COUNTER_NAME, "dryRun", String.valueOf(isDryRun)); accounts - .flatMap(account -> - Flux.fromIterable(account.getDevices()) - .flatMap(device -> { - final byte[] createdAtCiphertext = EncryptDeviceCreationTimestampUtil.encrypt( - device.getCreated(), account.getIdentityKey(IdentityType.ACI), - device.getId(), device.getRegistrationId(IdentityType.ACI)); - - final Mono encryptTimestampMono = isDryRun - ? Mono.empty() - : Mono.fromFuture(() -> getCommandDependencies().accountsManager().updateDeviceAsync( - account, device.getId(), d -> d.setCreatedAtCiphertext(createdAtCiphertext)) - .thenRun(Util.NOOP)); - return encryptTimestampMono - .doOnSuccess(_ -> encryptedTimestampCounter.increment()) - .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)).maxBackoff(Duration.ofSeconds(4))) - .onErrorResume(throwable -> { - log.warn("Failed to encrypt creation timestamp on device {}, account {}", device.getId(), account.getUuid(), throwable); - return Mono.empty(); - }); - }, MAX_CONCURRENCY) - .then() - .doOnSuccess(_ -> processedAccountCounter.increment()), - MAX_CONCURRENCY) + .flatMap(account -> { + Mono encryptTimestampMono = isDryRun + ? Mono.empty() + : Mono.fromFuture( + () -> getCommandDependencies().accountsManager().updateAsync(account, a -> { + final IdentityKey aciIdentityKey = account.getIdentityKey(IdentityType.ACI); + for (final Device device : a.getDevices()) { + final byte[] createdAtCiphertext = EncryptDeviceCreationTimestampUtil.encrypt( + device.getCreated(), aciIdentityKey, + device.getId(), device.getRegistrationId(IdentityType.ACI)); + device.setCreatedAtCiphertext(createdAtCiphertext); + } + }).thenRun(Util.NOOP)); + return encryptTimestampMono + .doOnSuccess(_ -> processedAccountCounter.increment()) + .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)).maxBackoff(Duration.ofSeconds(4))) + .onErrorResume(throwable -> { + log.warn("Failed to encrypt creation timestamps on account {}", account.getUuid(), throwable); + return Mono.empty(); + }); + }, MAX_CONCURRENCY) .then() .block(); + + log.info("Finished encrypting device timestamps"); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/EncryptDeviceCreationTimestampCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/EncryptDeviceCreationTimestampCommandTest.java index cfe08d5e3..9aa6ff935 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/workers/EncryptDeviceCreationTimestampCommandTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/EncryptDeviceCreationTimestampCommandTest.java @@ -8,26 +8,25 @@ package org.whispersystems.textsecuregcm.workers; import net.sourceforge.argparse4j.inf.Namespace; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -import org.mockito.MockedStatic; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.util.EncryptDeviceCreationTimestampUtil; -import org.whispersystems.textsecuregcm.util.TestRandomUtil; +import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; +import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; import reactor.core.publisher.Flux; import java.util.List; +import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyByte; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -63,44 +62,32 @@ public class EncryptDeviceCreationTimestampCommandTest { @ParameterizedTest @ValueSource(booleans = {true, false}) void crawlAccounts(final boolean isDryRun) { - final Account account = mock(Account.class); - final Device device = mock(Device.class); + final String number = "+14152222222"; + final UUID uuid = UUID.randomUUID(); - final IdentityKey identityKey = new IdentityKey(ECKeyPair.generate().getPublicKey()); - final byte deviceId = (byte) 1; - final long createdAt = System.currentTimeMillis(); - final int registrationId = 123; - - when(account.getDevices()).thenReturn(List.of(device)); - when(account.getIdentityKey(IdentityType.ACI)).thenReturn(identityKey); - when(device.getCreated()).thenReturn(createdAt); - when(device.getId()).thenReturn(deviceId); - when(device.getRegistrationId(IdentityType.ACI)).thenReturn(registrationId); + final Account testAccount = AccountsHelper.generateTestAccount(number, uuid, UUID.randomUUID(), List.of( + DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + testAccount.setIdentityKey(new IdentityKey(ECKeyPair.generate().getPublicKey())); final AccountsManager accountsManager = mock(AccountsManager.class); - when(accountsManager.updateDeviceAsync(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(accountsManager.updateAsync(any(), any())).thenAnswer(invocation -> { + final Account account = invocation.getArgument(0); + final Consumer updater = invocation.getArgument(1); + updater.accept(account); + return CompletableFuture.completedFuture(account); + }); final EncryptDeviceCreationTimestampCommand encryptDeviceCreationTimestampCommand = new TestEncryptDeviceCreationTimestampCommand(accountsManager, isDryRun); - try (MockedStatic mockUtil = mockStatic(EncryptDeviceCreationTimestampUtil.class)) { - mockUtil.when(() -> EncryptDeviceCreationTimestampUtil.encrypt(anyLong(), any(), anyByte(), anyInt())) - .thenReturn(TestRandomUtil.nextBytes(56)); + encryptDeviceCreationTimestampCommand.crawlAccounts(Flux.just(testAccount)); - encryptDeviceCreationTimestampCommand.crawlAccounts(Flux.just(account)); - - mockUtil.verify(() -> EncryptDeviceCreationTimestampUtil.encrypt( - eq(createdAt), - eq(identityKey), - eq(deviceId), - eq(registrationId) - )); - - if (isDryRun) { - verify(accountsManager, never()).updateDeviceAsync(any(), anyByte(), any()); - } else { - verify(accountsManager).updateDeviceAsync(eq(account), eq(deviceId), any()); - } + if (isDryRun) { + verify(accountsManager, never()).updateAsync(any(), any()); + assertNull(testAccount.getDevices().getFirst().getCreatedAtCiphertext()); + } else { + verify(accountsManager).updateAsync(eq(testAccount), any()); + assertNotNull(testAccount.getDevices().getFirst().getCreatedAtCiphertext()); } } }