diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/EncryptDeviceCreationTimestampCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/EncryptDeviceCreationTimestampCommand.java index 65d5faac7..e6214f431 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/EncryptDeviceCreationTimestampCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/EncryptDeviceCreationTimestampCommand.java @@ -11,6 +11,10 @@ import com.google.common.annotations.VisibleForTesting; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Function; import net.sourceforge.argparse4j.inf.Subparser; import org.signal.libsignal.protocol.IdentityKey; import org.slf4j.Logger; @@ -28,6 +32,9 @@ public class EncryptDeviceCreationTimestampCommand extends AbstractSinglePassCra @VisibleForTesting static final String DRY_RUN_ARGUMENT = "dry-run"; + @VisibleForTesting + static final String BUFFER_ARGUMENT = "buffer"; + private static final int MAX_CONCURRENCY = 16; private static final String PROCESSED_ACCOUNT_COUNTER_NAME = @@ -49,14 +56,33 @@ public class EncryptDeviceCreationTimestampCommand extends AbstractSinglePassCra .required(false) .setDefault(true) .help("If true, don't actually update device records"); + + subparser.addArgument("--buffer") + .type(Integer.class) + .dest(BUFFER_ARGUMENT) + .setDefault(16_384) + .help("Records to buffer"); } @Override protected void crawlAccounts(final Flux accounts) { final boolean isDryRun = getNamespace().getBoolean(DRY_RUN_ARGUMENT); + final int bufferSize = getNamespace().getInt(BUFFER_ARGUMENT); + final Counter processedAccountCounter = Metrics.counter(PROCESSED_ACCOUNT_COUNTER_NAME, "dryRun", String.valueOf(isDryRun)); + accounts + // We've partially processed enough accounts now that this should speed up the crawler + .filter(a -> a.getDevices().stream().anyMatch(d -> d.getCreatedAtCiphertext() == null || d.getCreatedAtCiphertext().length == 0)) + .buffer(bufferSize) + .map(source -> { + final List shuffled = new ArrayList<>(source); + Collections.shuffle(shuffled); + return shuffled; + }) + .limitRate(2) + .flatMapIterable(Function.identity()) .flatMap(account -> { Mono encryptTimestampMono = isDryRun ? Mono.empty() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/EncryptDeviceCreationTimestampCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/EncryptDeviceCreationTimestampCommandTest.java index 9aa6ff935..195e95b39 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/workers/EncryptDeviceCreationTimestampCommandTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/EncryptDeviceCreationTimestampCommandTest.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.workers; import net.sourceforge.argparse4j.inf.Namespace; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.signal.libsignal.protocol.IdentityKey; @@ -16,6 +17,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; +import org.whispersystems.textsecuregcm.util.TestRandomUtil; import reactor.core.publisher.Flux; import java.util.List; import java.util.UUID; @@ -28,6 +30,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -46,6 +49,7 @@ public class EncryptDeviceCreationTimestampCommandTest { namespace = mock(Namespace.class); when(namespace.getBoolean(EncryptDeviceCreationTimestampCommand.DRY_RUN_ARGUMENT)).thenReturn(isDryRun); + when(namespace.getInt(EncryptDeviceCreationTimestampCommand.BUFFER_ARGUMENT)).thenReturn(5); } @Override @@ -90,4 +94,46 @@ public class EncryptDeviceCreationTimestampCommandTest { assertNotNull(testAccount.getDevices().getFirst().getCreatedAtCiphertext()); } } + + @Test + void crawlAccountsWithEncryptedTimestamps() { + final Account unencryptedTimestampAccount = AccountsHelper.generateTestAccount("+14152222222", UUID.randomUUID(), UUID.randomUUID(), List.of( + DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + unencryptedTimestampAccount.setIdentityKey(new IdentityKey(ECKeyPair.generate().getPublicKey())); + + final Account encryptedTimestampAccount = AccountsHelper.generateTestAccount("+14152222223", UUID.randomUUID(), UUID.randomUUID(), List.of( + DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + encryptedTimestampAccount.setIdentityKey(new IdentityKey(ECKeyPair.generate().getPublicKey())); + encryptedTimestampAccount.getDevices().getFirst().setCreatedAtCiphertext(TestRandomUtil.nextBytes(56)); + + final Account halfEncryptedTimestampAccount = AccountsHelper.generateTestAccount("+14152222224", UUID.randomUUID(), UUID.randomUUID(), List.of( + DevicesHelper.createDevice(Device.PRIMARY_ID), DevicesHelper.createDevice((byte) 2)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + halfEncryptedTimestampAccount.setIdentityKey(new IdentityKey(ECKeyPair.generate().getPublicKey())); + halfEncryptedTimestampAccount.getDevices().getFirst().setCreatedAtCiphertext(TestRandomUtil.nextBytes(56)); + + final AccountsManager accountsManager = mock(AccountsManager.class); + when(accountsManager.updateAsync(any(), any())).thenAnswer(invocation -> { + final Account account = invocation.getArgument(0); + final Consumer updater = invocation.getArgument(1); + updater.accept(account); + return CompletableFuture.completedFuture(account); + }); + + final EncryptDeviceCreationTimestampCommand encryptDeviceCreationTimestampCommand = + new TestEncryptDeviceCreationTimestampCommand(accountsManager, false); + + assertNull(unencryptedTimestampAccount.getDevices().getFirst().getCreatedAtCiphertext()); + assertNull(halfEncryptedTimestampAccount.getDevices().get(1).getCreatedAtCiphertext()); + + encryptDeviceCreationTimestampCommand.crawlAccounts(Flux.just(unencryptedTimestampAccount, + encryptedTimestampAccount, + halfEncryptedTimestampAccount)); + + verify(accountsManager, times(1)).updateAsync(eq(unencryptedTimestampAccount), any()); + verify(accountsManager, never()).updateAsync(eq(encryptedTimestampAccount), any()); + verify(accountsManager, times(1)).updateAsync(eq(halfEncryptedTimestampAccount), any()); + + assertNotNull(unencryptedTimestampAccount.getDevices().getFirst().getCreatedAtCiphertext()); + assertNotNull(halfEncryptedTimestampAccount.getDevices().get(1).getCreatedAtCiphertext()); + } }