diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java index 876b247e81..29442e9d4d 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java @@ -136,7 +136,8 @@ public class ApplicationDependencyProvider implements ApplicationDependencies.Pr provideGroupsV2Operations(signalServiceConfiguration).getProfileOperations(), SignalExecutors.newCachedBoundedExecutor("signal-messages", ThreadUtil.PRIORITY_IMPORTANT_BACKGROUND_THREAD, 1, 16, 30), ByteUnit.KILOBYTES.toBytes(256), - FeatureFlags.okHttpAutomaticRetry()); + FeatureFlags.okHttpAutomaticRetry(), + FeatureFlags.useRxMessageSending()); } @Override diff --git a/app/src/main/java/org/thoughtcrime/securesms/util/FeatureFlags.java b/app/src/main/java/org/thoughtcrime/securesms/util/FeatureFlags.java index fbe2f5148a..ffe0c4fd5c 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/util/FeatureFlags.java +++ b/app/src/main/java/org/thoughtcrime/securesms/util/FeatureFlags.java @@ -124,6 +124,7 @@ public final class FeatureFlags { private static final String RETRY_RECEIPT_MAX_COUNT_RESET_AGE = "android.retryReceipt.maxCountResetAge"; private static final String PREKEY_FORCE_REFRESH_INTERVAL = "android.prekeyForceRefreshInterval"; private static final String CDSI_LIBSIGNAL_NET = "android.cds.libsignal"; + private static final String RX_MESSAGE_SEND = "android.rxMessageSend"; /** * We will only store remote values for flags in this set. If you want a flag to be controllable @@ -200,7 +201,8 @@ public final class FeatureFlags { RETRY_RECEIPT_MAX_COUNT, RETRY_RECEIPT_MAX_COUNT_RESET_AGE, PREKEY_FORCE_REFRESH_INTERVAL, - CDSI_LIBSIGNAL_NET + CDSI_LIBSIGNAL_NET, + RX_MESSAGE_SEND ); @VisibleForTesting @@ -274,7 +276,8 @@ public final class FeatureFlags { RETRY_RECEIPT_MAX_COUNT, RETRY_RECEIPT_MAX_COUNT_RESET_AGE, PREKEY_FORCE_REFRESH_INTERVAL, - CDSI_LIBSIGNAL_NET + CDSI_LIBSIGNAL_NET, + RX_MESSAGE_SEND ); /** @@ -714,6 +717,11 @@ public final class FeatureFlags { return getBoolean(CDSI_LIBSIGNAL_NET, false); } + /** Use Rx threading model to do sends. */ + public static boolean useRxMessageSending() { + return getBoolean(RX_MESSAGE_SEND, false); + } + /** Only for rendering debug info. */ public static synchronized @NonNull Map getMemoryValues() { return new TreeMap<>(REMOTE_VALUES); diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/CancelationException.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/CancelationException.java index 221433ae9a..cd1692f092 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/CancelationException.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/CancelationException.java @@ -3,4 +3,10 @@ package org.whispersystems.signalservice.api; import java.io.IOException; public class CancelationException extends IOException { + public CancelationException() { + } + + public CancelationException(Throwable cause) { + super(cause); + } } diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java index 6b2d271f31..411a449cd3 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java @@ -5,6 +5,7 @@ */ package org.whispersystems.signalservice.api; +import org.signal.core.util.Base64; import org.signal.libsignal.metadata.certificate.SenderCertificate; import org.signal.libsignal.protocol.IdentityKeyPair; import org.signal.libsignal.protocol.InvalidKeyException; @@ -85,6 +86,7 @@ import org.whispersystems.signalservice.api.util.Uint64RangeException; import org.whispersystems.signalservice.api.util.Uint64Util; import org.whispersystems.signalservice.api.util.UuidUtil; import org.whispersystems.signalservice.api.websocket.WebSocketUnavailableException; +import org.whispersystems.signalservice.internal.ServiceResponse; import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration; import org.whispersystems.signalservice.internal.crypto.AttachmentDigest; import org.whispersystems.signalservice.internal.crypto.PaddingInputStream; @@ -128,7 +130,6 @@ import org.whispersystems.signalservice.internal.push.http.PartialSendBatchCompl import org.whispersystems.signalservice.internal.push.http.PartialSendCompleteListener; import org.whispersystems.signalservice.internal.push.http.ResumableUploadSpec; import org.whispersystems.signalservice.internal.util.Util; -import org.signal.core.util.Base64; import org.whispersystems.util.ByteArrayUtil; import java.io.IOException; @@ -151,7 +152,13 @@ import java.util.concurrent.Future; import java.util.stream.Collectors; import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import io.reactivex.rxjava3.core.Observable; +import io.reactivex.rxjava3.core.Scheduler; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.schedulers.Schedulers; +import kotlin.Unit; import okio.ByteString; /** @@ -179,6 +186,7 @@ public class SignalServiceMessageSender { private final ExecutorService executor; private final long maxEnvelopeSize; + private final boolean useRxMessageSend; public SignalServiceMessageSender(SignalServiceConfiguration urls, CredentialsProvider credentialsProvider, @@ -190,7 +198,8 @@ public class SignalServiceMessageSender { ClientZkProfileOperations clientZkProfileOperations, ExecutorService executor, long maxEnvelopeSize, - boolean automaticNetworkRetry) + boolean automaticNetworkRetry, + boolean useRxMessageSend) { this.socket = new PushServiceSocket(urls, credentialsProvider, signalAgent, clientZkProfileOperations, automaticNetworkRetry); this.aciStore = store.aci(); @@ -204,6 +213,7 @@ public class SignalServiceMessageSender { this.executor = executor != null ? executor : Executors.newSingleThreadExecutor(); this.maxEnvelopeSize = maxEnvelopeSize; this.localPniIdentity = store.pni().getIdentityKeyPair(); + this.useRxMessageSend = useRxMessageSend; } /** @@ -514,6 +524,7 @@ public class SignalServiceMessageSender { long timestamp = System.currentTimeMillis(); Log.d(TAG, "[" + timestamp + "] Sending SKDM to " + recipients.size() + " recipients for DistributionId " + distributionId); + return sendMessage(recipients, getTargetUnidentifiedAccess(unidentifiedAccess), timestamp, envelopeContent, false, null, null, null, urgent, story); } @@ -1912,6 +1923,10 @@ public class SignalServiceMessageSender { boolean story) throws IOException { + if (useRxMessageSend) { + return sendMessageRx(recipients, unidentifiedAccess, timestamp, content, online, partialListener, cancelationSignal, sendEvents, urgent, story); + } + Log.d(TAG, "[" + timestamp + "] Sending to " + recipients.size() + " recipients."); enforceMaxContentSize(content); @@ -2088,6 +2103,303 @@ public class SignalServiceMessageSender { throw new IOException("Failed to resolve conflicts after " + RETRY_COUNT + " attempts!"); } + /** + * Send a message to multiple recipients. + * + * @return An unordered list of a {@link SendMessageResult} for each send. + * @throws IOException - Unknown failure or a failure not representable by an unsuccessful {@code SendMessageResult}. + */ + private List sendMessageRx(List recipients, + List> unidentifiedAccess, + long timestamp, + EnvelopeContent content, + boolean online, + PartialSendCompleteListener partialListener, + CancelationSignal cancelationSignal, + @Nullable SendEvents sendEvents, + boolean urgent, + boolean story) + throws IOException + { + Log.d(TAG, "[" + timestamp + "] Sending to " + recipients.size() + " recipients via Rx."); + enforceMaxContentSize(content); + + long startTime = System.currentTimeMillis(); + List> singleResults = new LinkedList<>(); + Iterator recipientIterator = recipients.iterator(); + Iterator> unidentifiedAccessIterator = unidentifiedAccess.iterator(); + + while (recipientIterator.hasNext()) { + SignalServiceAddress recipient = recipientIterator.next(); + Optional access = unidentifiedAccessIterator.next(); + singleResults.add(sendMessageRx(recipient, access, timestamp, content, online, cancelationSignal, sendEvents, urgent, story, 0).toObservable()); + } + + List results; + try { + results = Observable.mergeDelayError(singleResults, Integer.MAX_VALUE, 1) + .observeOn(Schedulers.io(), true) + .scan(new ArrayList(singleResults.size()), (state, result) -> { + state.add(result); + if (partialListener != null) { + partialListener.onPartialSendComplete(result); + } + return state; + }) + .lastOrError() + .blockingGet(); + } catch (RuntimeException e) { + Throwable cause = e.getCause(); + if (cause instanceof IOException) { + throw (IOException) cause; + } else if (cause instanceof InterruptedException) { + throw new CancelationException(e); + } else { + throw e; + } + } + + double sendsForAverage = 0; + for (SendMessageResult result : results) { + if (result.getSuccess() != null && result.getSuccess().getDuration() != -1) { + sendsForAverage++; + } + } + + double average = 0; + if (sendsForAverage > 0) { + for (SendMessageResult result : results) { + if (result.getSuccess() != null && result.getSuccess().getDuration() != -1) { + average += result.getSuccess().getDuration() / sendsForAverage; + } + } + } + + Log.d(TAG, "[" + timestamp + "] Completed send to " + recipients.size() + " recipients in " + (System.currentTimeMillis() - startTime) + " ms, with an average time of " + Math.round(average) + " ms per send via Rx."); + return results; + } + + /** + * Sends a message over the appropriate websocket, falls back to REST when unavailable, and emits a {@link SendMessageResult} for most business + * logic error cases. + *

+ * Uses a "feature" or Rx where if no {@link Single#subscribeOn(Scheduler)} operator is used, the subscribing thread is used to perform the + * initial work. This allows the calling thread to do the starting of the send work (encryption and putting it on the wire) and can be called + * multiple times in a loop, but allow the network transit/processing/error retry logic to run on a background thread. + *

+ * Processing happens on the background thread via an {@link Single#observeOn(Scheduler)} call after the encrypt and send. Error + * handling operators are added after the observe so they will also run on a background thread. Retry logic during error handling + * is a recursive call, so error handling thread becomes the method "calling and subscribing" thread so all retries will perform the + * encryption/send/processing on that background thread. + * + * @return A single that wraps success and business failures as a {@link SendMessageResult} but will still emit unhandled/unrecoverable + * errors via {@code onError} + */ + private Single sendMessageRx(SignalServiceAddress recipient, + final Optional unidentifiedAccess, + long timestamp, + EnvelopeContent content, + boolean online, + CancelationSignal cancelationSignal, + @Nullable SendEvents sendEvents, + boolean urgent, + boolean story, + int retryCount) + { + long startTime = System.currentTimeMillis(); + enforceMaxContentSize(content); + + Single messagesSingle = Single.fromCallable(() -> { + OutgoingPushMessageList messages = getEncryptedMessages(recipient, unidentifiedAccess, timestamp, content, online, urgent, story); + + if (retryCount == 0 && sendEvents != null) { + sendEvents.onMessageEncrypted(); + } + + if (content.getContent().isPresent() && content.getContent().get().syncMessage != null && content.getContent().get().syncMessage.sent != null) { + Log.d(TAG, "[sendMessage][" + timestamp + "] Sending a sent sync message to devices: " + messages.getDevices() + " via Rx"); + } else if (content.getContent().isPresent() && content.getContent().get().senderKeyDistributionMessage != null) { + Log.d(TAG, "[sendMessage][" + timestamp + "] Sending a SKDM to " + messages.getDestination() + " for devices: " + messages.getDevices() + (content.getContent().get().dataMessage != null ? " (it's piggy-backing on a DataMessage) via Rx" : " via Rx")); + } + + return messages; + }); + + Single sendWithFallback = messagesSingle + .flatMap(messages -> { + if (cancelationSignal != null && cancelationSignal.isCanceled()) { + return Single.error(new CancelationException()); + } + + return messagingService.send(messages, unidentifiedAccess, story) + .map(r -> new kotlin.Pair<>(messages, r)); + }) + .observeOn(Schedulers.io()) + .flatMap(pair -> { + final OutgoingPushMessageList messages = pair.getFirst(); + final ServiceResponse serviceResponse = pair.getSecond(); + + if (serviceResponse.getResult().isPresent()) { + SendMessageResponse response = serviceResponse.getResult().get(); + SendMessageResult result = SendMessageResult.success( + recipient, + messages.getDevices(), + response.sentUnidentified(), + response.getNeedsSync() || aciStore.isMultiDevice(), + System.currentTimeMillis() - startTime, + content.getContent() + ); + return Single.just(result); + } else { + if (cancelationSignal != null && cancelationSignal.isCanceled()) { + return Single.error(new CancelationException()); + } + + //noinspection OptionalGetWithoutIsPresent + Throwable throwable = serviceResponse.getApplicationError().or(serviceResponse::getExecutionError).get(); + + if (throwable instanceof InvalidUnidentifiedAccessHeaderException || + throwable instanceof UnregisteredUserException || + throwable instanceof MismatchedDevicesException || + throwable instanceof StaleDevicesException) + { + // Non-technical failures shouldn't be retried with socket + return Single.error(throwable); + } else if (throwable instanceof WebSocketUnavailableException) { + Log.i(TAG, "[sendMessage][" + timestamp + "] " + (unidentifiedAccess.isPresent() ? "Unidentified " : "") + "pipe unavailable, falling back... (" + throwable.getClass().getSimpleName() + ": " + throwable.getMessage() + ")"); + } else if (throwable instanceof IOException) { + Throwable cause = throwable.getCause() != null ? throwable.getCause() : throwable; + Log.w(TAG, "[sendMessage][" + timestamp + "] " + (unidentifiedAccess.isPresent() ? "Unidentified " : "") + "pipe failed, falling back... (" + cause.getClass().getSimpleName() + ": " + cause.getMessage() + ")"); + } + + return Single.fromCallable(() -> { + SendMessageResponse response = socket.sendMessage(messages, unidentifiedAccess, story); + return SendMessageResult.success( + recipient, + messages.getDevices(), + response.sentUnidentified(), + response.getNeedsSync() || aciStore.isMultiDevice(), + System.currentTimeMillis() - startTime, + content.getContent() + ); + }).subscribeOn(Schedulers.io()); + } + }); + + return sendWithFallback.onErrorResumeNext(t -> { + if (cancelationSignal != null && cancelationSignal.isCanceled()) { + return Single.error(new CancelationException()); + } + + if (retryCount >= RETRY_COUNT) { + return Single.error(t); + } + + if (t instanceof InvalidKeyException) { + Log.w(TAG, t); + return sendMessageRx( + recipient, + Optional.empty(), + timestamp, + content, + online, + cancelationSignal, + sendEvents, + urgent, + story, + retryCount + 1 + ); + } else if (t instanceof AuthorizationFailedException) { + if (unidentifiedAccess.isPresent()) { + Log.w(TAG, "Got an AuthorizationFailedException when trying to send using sealed sender. Falling back."); + return sendMessageRx( + recipient, + Optional.empty(), + timestamp, + content, + online, + cancelationSignal, + sendEvents, + urgent, + story, + retryCount + 1 + ); + } else { + Log.w(TAG, "Got an AuthorizationFailedException without using sealed sender!", t); + return Single.error(t); + } + } else if (t instanceof MismatchedDevicesException) { + MismatchedDevicesException mde = (MismatchedDevicesException) t; + Log.w(TAG, "[sendMessage][" + timestamp + "] Handling mismatched devices. (" + mde.getMessage() + ")"); + + return Single.fromCallable(() -> { + handleMismatchedDevices(socket, recipient, mde.getMismatchedDevices()); + return Unit.INSTANCE; + }) + .flatMap(unused -> sendMessageRx( + recipient, + unidentifiedAccess, + timestamp, + content, + online, + cancelationSignal, + sendEvents, + urgent, + story, + retryCount + 1) + ); + } else if (t instanceof StaleDevicesException) { + StaleDevicesException ste = (StaleDevicesException) t; + Log.w(TAG, "[sendMessage][" + timestamp + "] Handling stale devices. (" + ste.getMessage() + ")"); + + return Single.fromCallable(() -> { + handleStaleDevices(recipient, ste.getStaleDevices()); + return Unit.INSTANCE; + }) + .flatMap(unused -> sendMessageRx( + recipient, + unidentifiedAccess, + timestamp, + content, + online, + cancelationSignal, + sendEvents, + urgent, + story, + retryCount + 1) + ); + } + + return Single.error(t); + }).onErrorResumeNext(t -> { + if (t instanceof UntrustedIdentityException) { + Log.w(TAG, "[" + timestamp + "] Hit identity mismatch: " + recipient.getIdentifier(), t); + return Single.just(SendMessageResult.identityFailure(recipient, ((UntrustedIdentityException) t).getIdentityKey())); + } else if (t instanceof UnregisteredUserException) { + Log.w(TAG, "[" + timestamp + "] Hit unregistered user: " + recipient.getIdentifier()); + return Single.just(SendMessageResult.unregisteredFailure(recipient)); + } else if (t instanceof PushNetworkException) { + Log.w(TAG, "[" + timestamp + "] Hit network failure: " + recipient.getIdentifier(), t); + return Single.just(SendMessageResult.networkFailure(recipient)); + } else if (t instanceof ServerRejectedException) { + Log.w(TAG, "[" + timestamp + "] Hit server rejection: " + recipient.getIdentifier(), t); + return Single.error(t); + } else if (t instanceof ProofRequiredException) { + Log.w(TAG, "[" + timestamp + "] Hit proof required: " + recipient.getIdentifier(), t); + return Single.just(SendMessageResult.proofRequiredFailure(recipient, (ProofRequiredException) t)); + } else if (t instanceof RateLimitException) { + Log.w(TAG, "[" + timestamp + "] Hit rate limit: " + recipient.getIdentifier(), t); + return Single.just(SendMessageResult.rateLimitFailure(recipient, (RateLimitException) t)); + } else if (t instanceof InvalidPreKeyException) { + Log.w(TAG, "[" + timestamp + "] Hit invalid prekey: " + recipient.getIdentifier(), t); + return Single.just(SendMessageResult.invalidPreKeyFailure(recipient)); + } else { + Log.w(TAG, "[" + timestamp + "] Hit unknown exception: " + recipient.getIdentifier(), t); + return Single.error(new IOException(t)); + } + }); + } + /** * Will send a message using sender keys to all of the specified recipients. It is assumed that * all of the recipients have UUIDs.