Improve local fanout send performance.

This commit is contained in:
Cody Henthorne
2024-03-05 09:30:48 -05:00
committed by Alex Hart
parent 9f197b12ed
commit 619038f27d
4 changed files with 332 additions and 5 deletions

View File

@@ -136,7 +136,8 @@ public class ApplicationDependencyProvider implements ApplicationDependencies.Pr
provideGroupsV2Operations(signalServiceConfiguration).getProfileOperations(),
SignalExecutors.newCachedBoundedExecutor("signal-messages", ThreadUtil.PRIORITY_IMPORTANT_BACKGROUND_THREAD, 1, 16, 30),
ByteUnit.KILOBYTES.toBytes(256),
FeatureFlags.okHttpAutomaticRetry());
FeatureFlags.okHttpAutomaticRetry(),
FeatureFlags.useRxMessageSending());
}
@Override

View File

@@ -124,6 +124,7 @@ public final class FeatureFlags {
private static final String RETRY_RECEIPT_MAX_COUNT_RESET_AGE = "android.retryReceipt.maxCountResetAge";
private static final String PREKEY_FORCE_REFRESH_INTERVAL = "android.prekeyForceRefreshInterval";
private static final String CDSI_LIBSIGNAL_NET = "android.cds.libsignal";
private static final String RX_MESSAGE_SEND = "android.rxMessageSend";
/**
* We will only store remote values for flags in this set. If you want a flag to be controllable
@@ -200,7 +201,8 @@ public final class FeatureFlags {
RETRY_RECEIPT_MAX_COUNT,
RETRY_RECEIPT_MAX_COUNT_RESET_AGE,
PREKEY_FORCE_REFRESH_INTERVAL,
CDSI_LIBSIGNAL_NET
CDSI_LIBSIGNAL_NET,
RX_MESSAGE_SEND
);
@VisibleForTesting
@@ -274,7 +276,8 @@ public final class FeatureFlags {
RETRY_RECEIPT_MAX_COUNT,
RETRY_RECEIPT_MAX_COUNT_RESET_AGE,
PREKEY_FORCE_REFRESH_INTERVAL,
CDSI_LIBSIGNAL_NET
CDSI_LIBSIGNAL_NET,
RX_MESSAGE_SEND
);
/**
@@ -714,6 +717,11 @@ public final class FeatureFlags {
return getBoolean(CDSI_LIBSIGNAL_NET, false);
}
/** Use Rx threading model to do sends. */
public static boolean useRxMessageSending() {
return getBoolean(RX_MESSAGE_SEND, false);
}
/** Only for rendering debug info. */
public static synchronized @NonNull Map<String, Object> getMemoryValues() {
return new TreeMap<>(REMOTE_VALUES);

View File

@@ -3,4 +3,10 @@ package org.whispersystems.signalservice.api;
import java.io.IOException;
public class CancelationException extends IOException {
public CancelationException() {
}
public CancelationException(Throwable cause) {
super(cause);
}
}

View File

@@ -5,6 +5,7 @@
*/
package org.whispersystems.signalservice.api;
import org.signal.core.util.Base64;
import org.signal.libsignal.metadata.certificate.SenderCertificate;
import org.signal.libsignal.protocol.IdentityKeyPair;
import org.signal.libsignal.protocol.InvalidKeyException;
@@ -85,6 +86,7 @@ import org.whispersystems.signalservice.api.util.Uint64RangeException;
import org.whispersystems.signalservice.api.util.Uint64Util;
import org.whispersystems.signalservice.api.util.UuidUtil;
import org.whispersystems.signalservice.api.websocket.WebSocketUnavailableException;
import org.whispersystems.signalservice.internal.ServiceResponse;
import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration;
import org.whispersystems.signalservice.internal.crypto.AttachmentDigest;
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream;
@@ -128,7 +130,6 @@ import org.whispersystems.signalservice.internal.push.http.PartialSendBatchCompl
import org.whispersystems.signalservice.internal.push.http.PartialSendCompleteListener;
import org.whispersystems.signalservice.internal.push.http.ResumableUploadSpec;
import org.whispersystems.signalservice.internal.util.Util;
import org.signal.core.util.Base64;
import org.whispersystems.util.ByteArrayUtil;
import java.io.IOException;
@@ -151,7 +152,13 @@ import java.util.concurrent.Future;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import io.reactivex.rxjava3.core.Observable;
import io.reactivex.rxjava3.core.Scheduler;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.schedulers.Schedulers;
import kotlin.Unit;
import okio.ByteString;
/**
@@ -179,6 +186,7 @@ public class SignalServiceMessageSender {
private final ExecutorService executor;
private final long maxEnvelopeSize;
private final boolean useRxMessageSend;
public SignalServiceMessageSender(SignalServiceConfiguration urls,
CredentialsProvider credentialsProvider,
@@ -190,7 +198,8 @@ public class SignalServiceMessageSender {
ClientZkProfileOperations clientZkProfileOperations,
ExecutorService executor,
long maxEnvelopeSize,
boolean automaticNetworkRetry)
boolean automaticNetworkRetry,
boolean useRxMessageSend)
{
this.socket = new PushServiceSocket(urls, credentialsProvider, signalAgent, clientZkProfileOperations, automaticNetworkRetry);
this.aciStore = store.aci();
@@ -204,6 +213,7 @@ public class SignalServiceMessageSender {
this.executor = executor != null ? executor : Executors.newSingleThreadExecutor();
this.maxEnvelopeSize = maxEnvelopeSize;
this.localPniIdentity = store.pni().getIdentityKeyPair();
this.useRxMessageSend = useRxMessageSend;
}
/**
@@ -514,6 +524,7 @@ public class SignalServiceMessageSender {
long timestamp = System.currentTimeMillis();
Log.d(TAG, "[" + timestamp + "] Sending SKDM to " + recipients.size() + " recipients for DistributionId " + distributionId);
return sendMessage(recipients, getTargetUnidentifiedAccess(unidentifiedAccess), timestamp, envelopeContent, false, null, null, null, urgent, story);
}
@@ -1912,6 +1923,10 @@ public class SignalServiceMessageSender {
boolean story)
throws IOException
{
if (useRxMessageSend) {
return sendMessageRx(recipients, unidentifiedAccess, timestamp, content, online, partialListener, cancelationSignal, sendEvents, urgent, story);
}
Log.d(TAG, "[" + timestamp + "] Sending to " + recipients.size() + " recipients.");
enforceMaxContentSize(content);
@@ -2088,6 +2103,303 @@ public class SignalServiceMessageSender {
throw new IOException("Failed to resolve conflicts after " + RETRY_COUNT + " attempts!");
}
/**
* Send a message to multiple recipients.
*
* @return An unordered list of a {@link SendMessageResult} for each send.
* @throws IOException - Unknown failure or a failure not representable by an unsuccessful {@code SendMessageResult}.
*/
private List<SendMessageResult> sendMessageRx(List<SignalServiceAddress> recipients,
List<Optional<UnidentifiedAccess>> unidentifiedAccess,
long timestamp,
EnvelopeContent content,
boolean online,
PartialSendCompleteListener partialListener,
CancelationSignal cancelationSignal,
@Nullable SendEvents sendEvents,
boolean urgent,
boolean story)
throws IOException
{
Log.d(TAG, "[" + timestamp + "] Sending to " + recipients.size() + " recipients via Rx.");
enforceMaxContentSize(content);
long startTime = System.currentTimeMillis();
List<Observable<SendMessageResult>> singleResults = new LinkedList<>();
Iterator<SignalServiceAddress> recipientIterator = recipients.iterator();
Iterator<Optional<UnidentifiedAccess>> unidentifiedAccessIterator = unidentifiedAccess.iterator();
while (recipientIterator.hasNext()) {
SignalServiceAddress recipient = recipientIterator.next();
Optional<UnidentifiedAccess> access = unidentifiedAccessIterator.next();
singleResults.add(sendMessageRx(recipient, access, timestamp, content, online, cancelationSignal, sendEvents, urgent, story, 0).toObservable());
}
List<SendMessageResult> results;
try {
results = Observable.mergeDelayError(singleResults, Integer.MAX_VALUE, 1)
.observeOn(Schedulers.io(), true)
.scan(new ArrayList<SendMessageResult>(singleResults.size()), (state, result) -> {
state.add(result);
if (partialListener != null) {
partialListener.onPartialSendComplete(result);
}
return state;
})
.lastOrError()
.blockingGet();
} catch (RuntimeException e) {
Throwable cause = e.getCause();
if (cause instanceof IOException) {
throw (IOException) cause;
} else if (cause instanceof InterruptedException) {
throw new CancelationException(e);
} else {
throw e;
}
}
double sendsForAverage = 0;
for (SendMessageResult result : results) {
if (result.getSuccess() != null && result.getSuccess().getDuration() != -1) {
sendsForAverage++;
}
}
double average = 0;
if (sendsForAverage > 0) {
for (SendMessageResult result : results) {
if (result.getSuccess() != null && result.getSuccess().getDuration() != -1) {
average += result.getSuccess().getDuration() / sendsForAverage;
}
}
}
Log.d(TAG, "[" + timestamp + "] Completed send to " + recipients.size() + " recipients in " + (System.currentTimeMillis() - startTime) + " ms, with an average time of " + Math.round(average) + " ms per send via Rx.");
return results;
}
/**
* Sends a message over the appropriate websocket, falls back to REST when unavailable, and emits a {@link SendMessageResult} for most business
* logic error cases.
* <p>
* Uses a "feature" or Rx where if no {@link Single#subscribeOn(Scheduler)} operator is used, the subscribing thread is used to perform the
* initial work. This allows the calling thread to do the starting of the send work (encryption and putting it on the wire) and can be called
* multiple times in a loop, but allow the network transit/processing/error retry logic to run on a background thread.
* <p>
* Processing happens on the background thread via an {@link Single#observeOn(Scheduler)} call after the encrypt and send. Error
* handling operators are added after the observe so they will also run on a background thread. Retry logic during error handling
* is a recursive call, so error handling thread becomes the method "calling and subscribing" thread so all retries will perform the
* encryption/send/processing on that background thread.
*
* @return A single that wraps success and business failures as a {@link SendMessageResult} but will still emit unhandled/unrecoverable
* errors via {@code onError}
*/
private Single<SendMessageResult> sendMessageRx(SignalServiceAddress recipient,
final Optional<UnidentifiedAccess> unidentifiedAccess,
long timestamp,
EnvelopeContent content,
boolean online,
CancelationSignal cancelationSignal,
@Nullable SendEvents sendEvents,
boolean urgent,
boolean story,
int retryCount)
{
long startTime = System.currentTimeMillis();
enforceMaxContentSize(content);
Single<OutgoingPushMessageList> messagesSingle = Single.fromCallable(() -> {
OutgoingPushMessageList messages = getEncryptedMessages(recipient, unidentifiedAccess, timestamp, content, online, urgent, story);
if (retryCount == 0 && sendEvents != null) {
sendEvents.onMessageEncrypted();
}
if (content.getContent().isPresent() && content.getContent().get().syncMessage != null && content.getContent().get().syncMessage.sent != null) {
Log.d(TAG, "[sendMessage][" + timestamp + "] Sending a sent sync message to devices: " + messages.getDevices() + " via Rx");
} else if (content.getContent().isPresent() && content.getContent().get().senderKeyDistributionMessage != null) {
Log.d(TAG, "[sendMessage][" + timestamp + "] Sending a SKDM to " + messages.getDestination() + " for devices: " + messages.getDevices() + (content.getContent().get().dataMessage != null ? " (it's piggy-backing on a DataMessage) via Rx" : " via Rx"));
}
return messages;
});
Single<SendMessageResult> sendWithFallback = messagesSingle
.flatMap(messages -> {
if (cancelationSignal != null && cancelationSignal.isCanceled()) {
return Single.error(new CancelationException());
}
return messagingService.send(messages, unidentifiedAccess, story)
.map(r -> new kotlin.Pair<>(messages, r));
})
.observeOn(Schedulers.io())
.flatMap(pair -> {
final OutgoingPushMessageList messages = pair.getFirst();
final ServiceResponse<SendMessageResponse> serviceResponse = pair.getSecond();
if (serviceResponse.getResult().isPresent()) {
SendMessageResponse response = serviceResponse.getResult().get();
SendMessageResult result = SendMessageResult.success(
recipient,
messages.getDevices(),
response.sentUnidentified(),
response.getNeedsSync() || aciStore.isMultiDevice(),
System.currentTimeMillis() - startTime,
content.getContent()
);
return Single.just(result);
} else {
if (cancelationSignal != null && cancelationSignal.isCanceled()) {
return Single.error(new CancelationException());
}
//noinspection OptionalGetWithoutIsPresent
Throwable throwable = serviceResponse.getApplicationError().or(serviceResponse::getExecutionError).get();
if (throwable instanceof InvalidUnidentifiedAccessHeaderException ||
throwable instanceof UnregisteredUserException ||
throwable instanceof MismatchedDevicesException ||
throwable instanceof StaleDevicesException)
{
// Non-technical failures shouldn't be retried with socket
return Single.error(throwable);
} else if (throwable instanceof WebSocketUnavailableException) {
Log.i(TAG, "[sendMessage][" + timestamp + "] " + (unidentifiedAccess.isPresent() ? "Unidentified " : "") + "pipe unavailable, falling back... (" + throwable.getClass().getSimpleName() + ": " + throwable.getMessage() + ")");
} else if (throwable instanceof IOException) {
Throwable cause = throwable.getCause() != null ? throwable.getCause() : throwable;
Log.w(TAG, "[sendMessage][" + timestamp + "] " + (unidentifiedAccess.isPresent() ? "Unidentified " : "") + "pipe failed, falling back... (" + cause.getClass().getSimpleName() + ": " + cause.getMessage() + ")");
}
return Single.fromCallable(() -> {
SendMessageResponse response = socket.sendMessage(messages, unidentifiedAccess, story);
return SendMessageResult.success(
recipient,
messages.getDevices(),
response.sentUnidentified(),
response.getNeedsSync() || aciStore.isMultiDevice(),
System.currentTimeMillis() - startTime,
content.getContent()
);
}).subscribeOn(Schedulers.io());
}
});
return sendWithFallback.onErrorResumeNext(t -> {
if (cancelationSignal != null && cancelationSignal.isCanceled()) {
return Single.error(new CancelationException());
}
if (retryCount >= RETRY_COUNT) {
return Single.error(t);
}
if (t instanceof InvalidKeyException) {
Log.w(TAG, t);
return sendMessageRx(
recipient,
Optional.empty(),
timestamp,
content,
online,
cancelationSignal,
sendEvents,
urgent,
story,
retryCount + 1
);
} else if (t instanceof AuthorizationFailedException) {
if (unidentifiedAccess.isPresent()) {
Log.w(TAG, "Got an AuthorizationFailedException when trying to send using sealed sender. Falling back.");
return sendMessageRx(
recipient,
Optional.empty(),
timestamp,
content,
online,
cancelationSignal,
sendEvents,
urgent,
story,
retryCount + 1
);
} else {
Log.w(TAG, "Got an AuthorizationFailedException without using sealed sender!", t);
return Single.error(t);
}
} else if (t instanceof MismatchedDevicesException) {
MismatchedDevicesException mde = (MismatchedDevicesException) t;
Log.w(TAG, "[sendMessage][" + timestamp + "] Handling mismatched devices. (" + mde.getMessage() + ")");
return Single.fromCallable(() -> {
handleMismatchedDevices(socket, recipient, mde.getMismatchedDevices());
return Unit.INSTANCE;
})
.flatMap(unused -> sendMessageRx(
recipient,
unidentifiedAccess,
timestamp,
content,
online,
cancelationSignal,
sendEvents,
urgent,
story,
retryCount + 1)
);
} else if (t instanceof StaleDevicesException) {
StaleDevicesException ste = (StaleDevicesException) t;
Log.w(TAG, "[sendMessage][" + timestamp + "] Handling stale devices. (" + ste.getMessage() + ")");
return Single.fromCallable(() -> {
handleStaleDevices(recipient, ste.getStaleDevices());
return Unit.INSTANCE;
})
.flatMap(unused -> sendMessageRx(
recipient,
unidentifiedAccess,
timestamp,
content,
online,
cancelationSignal,
sendEvents,
urgent,
story,
retryCount + 1)
);
}
return Single.error(t);
}).onErrorResumeNext(t -> {
if (t instanceof UntrustedIdentityException) {
Log.w(TAG, "[" + timestamp + "] Hit identity mismatch: " + recipient.getIdentifier(), t);
return Single.just(SendMessageResult.identityFailure(recipient, ((UntrustedIdentityException) t).getIdentityKey()));
} else if (t instanceof UnregisteredUserException) {
Log.w(TAG, "[" + timestamp + "] Hit unregistered user: " + recipient.getIdentifier());
return Single.just(SendMessageResult.unregisteredFailure(recipient));
} else if (t instanceof PushNetworkException) {
Log.w(TAG, "[" + timestamp + "] Hit network failure: " + recipient.getIdentifier(), t);
return Single.just(SendMessageResult.networkFailure(recipient));
} else if (t instanceof ServerRejectedException) {
Log.w(TAG, "[" + timestamp + "] Hit server rejection: " + recipient.getIdentifier(), t);
return Single.error(t);
} else if (t instanceof ProofRequiredException) {
Log.w(TAG, "[" + timestamp + "] Hit proof required: " + recipient.getIdentifier(), t);
return Single.just(SendMessageResult.proofRequiredFailure(recipient, (ProofRequiredException) t));
} else if (t instanceof RateLimitException) {
Log.w(TAG, "[" + timestamp + "] Hit rate limit: " + recipient.getIdentifier(), t);
return Single.just(SendMessageResult.rateLimitFailure(recipient, (RateLimitException) t));
} else if (t instanceof InvalidPreKeyException) {
Log.w(TAG, "[" + timestamp + "] Hit invalid prekey: " + recipient.getIdentifier(), t);
return Single.just(SendMessageResult.invalidPreKeyFailure(recipient));
} else {
Log.w(TAG, "[" + timestamp + "] Hit unknown exception: " + recipient.getIdentifier(), t);
return Single.error(new IOException(t));
}
});
}
/**
* Will send a message using sender keys to all of the specified recipients. It is assumed that
* all of the recipients have UUIDs.