Switch to an async SQS client.

This commit is contained in:
Jon Chambers
2021-07-26 11:40:53 -04:00
committed by Jon Chambers
parent a6066bfc2f
commit 34dbff6786
2 changed files with 87 additions and 118 deletions

View File

@@ -14,23 +14,19 @@ import com.google.common.annotations.VisibleForTesting;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import com.google.common.collect.Iterables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.SqsConfiguration;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.exception.SdkServiceException;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sqs.SqsClient;
import software.amazon.awssdk.services.sqs.SqsAsyncClient;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
public class DirectoryQueue {
@@ -41,22 +37,38 @@ public class DirectoryQueue {
private final Meter clientErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "clientError"));
private final Timer sendMessageBatchTimer = metricRegistry.timer(name(DirectoryQueue.class, "sendMessageBatch"));
private final List<String> queueUrls;
private final SqsClient sqs;
private final List<String> queueUrls;
private final SqsAsyncClient sqs;
private enum UpdateAction {
ADD("add"),
DELETE("delete");
private final String action;
UpdateAction(final String action) {
this.action = action;
}
public MessageAttributeValue toMessageAttributeValue() {
return MessageAttributeValue.builder().dataType("String").stringValue(action).build();
}
}
public DirectoryQueue(SqsConfiguration sqsConfig) {
StaticCredentialsProvider credentialsProvider = StaticCredentialsProvider.create(AwsBasicCredentials.create(
sqsConfig.getAccessKey(), sqsConfig.getAccessSecret()));
this.queueUrls = sqsConfig.getQueueUrls();
this.sqs = SqsClient.builder()
this.sqs = SqsAsyncClient.builder()
.region(Region.of(sqsConfig.getRegion()))
.credentialsProvider(credentialsProvider)
.build();
}
@VisibleForTesting
DirectoryQueue(final List<String> queueUrls, final SqsClient sqs) {
DirectoryQueue(final List<String> queueUrls, final SqsAsyncClient sqs) {
this.queueUrls = queueUrls;
this.sqs = sqs;
}
@@ -66,58 +78,44 @@ public class DirectoryQueue {
}
public void refreshAccount(final Account account) {
refreshAccounts(List.of(account));
}
public void refreshAccounts(final List<Account> accounts) {
final List<Pair<Account, String>> accountsAndActions = accounts.stream()
.map(account -> new Pair<>(account, account.isEnabled() && account.isDiscoverableByPhoneNumber() ? "add" : "delete"))
.collect(Collectors.toList());
sendUpdateMessages(accountsAndActions);
sendUpdateMessage(account, isDiscoverable(account) ? UpdateAction.ADD : UpdateAction.DELETE);
}
public void deleteAccount(final Account account) {
sendUpdateMessages(List.of(new Pair<>(account, "delete")));
sendUpdateMessage(account, UpdateAction.DELETE);
}
private void sendUpdateMessages(final List<Pair<Account, String>> accountsAndActions) {
private void sendUpdateMessage(final Account account, final UpdateAction action) {
for (final String queueUrl : queueUrls) {
for (final List<Pair<Account, String>> partition : Iterables.partition(accountsAndActions, 10)) {
final List<SendMessageBatchRequestEntry> entries = partition.stream().map(pair -> {
final Account account = pair.first();
final String action = pair.second();
final Timer.Context timerContext = sendMessageBatchTimer.time();
return SendMessageBatchRequestEntry.builder()
.messageBody("-")
.id(UUID.randomUUID().toString())
.messageDeduplicationId(UUID.randomUUID().toString())
.messageGroupId(account.getNumber())
.messageAttributes(Map.of(
"id", MessageAttributeValue.builder().dataType("String").stringValue(account.getNumber()).build(),
"uuid", MessageAttributeValue.builder().dataType("String").stringValue(account.getUuid().toString()).build(),
"action", MessageAttributeValue.builder().dataType("String").stringValue(action).build()
))
.build();
}).collect(Collectors.toList());
final SendMessageRequest request = SendMessageRequest.builder()
.queueUrl(queueUrl)
.messageBody("-")
.messageDeduplicationId(UUID.randomUUID().toString())
.messageGroupId(account.getNumber())
.messageAttributes(Map.of(
"id", MessageAttributeValue.builder().dataType("String").stringValue(account.getNumber()).build(),
"uuid", MessageAttributeValue.builder().dataType("String").stringValue(account.getUuid().toString()).build(),
"action", action.toMessageAttributeValue()
))
.build();
final SendMessageBatchRequest sendMessageBatchRequest = SendMessageBatchRequest.builder()
.queueUrl(queueUrl)
.entries(entries)
.build();
try (final Timer.Context ignored = sendMessageBatchTimer.time()) {
sqs.sendMessageBatch(sendMessageBatchRequest);
} catch (SdkServiceException ex) {
serviceErrorMeter.mark();
logger.warn("sqs service error: ", ex);
} catch (SdkClientException ex) {
clientErrorMeter.mark();
logger.warn("sqs client error: ", ex);
} catch (Throwable t) {
logger.warn("sqs unexpected error: ", t);
sqs.sendMessage(request).whenComplete((response, cause) -> {
try {
if (cause instanceof SdkServiceException) {
serviceErrorMeter.mark();
logger.warn("sqs service error", cause);
} else if (cause instanceof SdkClientException) {
clientErrorMeter.mark();
logger.warn("sqs client error", cause);
} else if (cause != null) {
logger.warn("sqs unexpected error", cause);
}
} finally {
timerContext.close();
}
}
});
}
}
}