Perform message decryptions in batches.

This commit is contained in:
Greyson Parrelli
2023-03-09 17:05:00 -05:00
parent 04baa7925f
commit 894095414a
17 changed files with 772 additions and 69 deletions

View File

@@ -1,7 +1,10 @@
package org.whispersystems.signalservice.api;
import com.google.protobuf.InvalidProtocolBufferException;
import org.signal.libsignal.protocol.logging.Log;
import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess;
import org.whispersystems.signalservice.api.messages.EnvelopeResponse;
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState;
import org.whispersystems.signalservice.api.websocket.WebSocketFactory;
import org.whispersystems.signalservice.api.websocket.WebSocketUnavailableException;
@@ -13,9 +16,10 @@ import org.whispersystems.signalservice.internal.websocket.WebsocketResponse;
import org.whispersystems.util.Base64;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import io.reactivex.rxjava3.core.Observable;
import io.reactivex.rxjava3.core.Single;
@@ -218,61 +222,105 @@ public final class SignalWebSocket {
}
/**
* <p>
* A blocking call that reads a message off the pipe. When this call returns, if the callback indicates the
* message was successfully processed, then the message will be ack'ed on the serve and will not be retransmitted.
* <p>
* This will return true if there are more messages to be read from the websocket, or false if the websocket is empty.
* <p>
* You can specify a {@link MessageReceivedCallback} that will be called before the received message is acknowledged.
* This allows you to write the received message to durable storage before acknowledging receipt of it to the
* server.
* <p>
* Important: This will only return `false` once for each connection. That means if you get false call readMessage()
* again on the same instance, you will not get an immediate `false` return value, and instead will block until
* you get an actual message. This will, however, reset if connection breaks (if, for instance, you lose and regain network).
* The reads a batch of messages off of the websocket.
*
* @param timeout The timeout to wait for.
* @param callback A callback that will be called before the message receipt is acknowledged to the server.
* @return The message read (same as the message sent through the callback).
* Rather than just provide you the batch as a return value, it will invoke the provided callback with the
* batch as an argument. If you are able to successfully process them, this method will then ack all of the
* messages so that they won't be re-delivered in the future.
*
* The return value of this method is a boolean indicating whether or not there are more messages in the
* queue to be read (true if there's still more, or false if you've drained everything).
*
* However, this return value is only really useful the first time you read from the websocket. That's because
* the websocket will only ever let you know if it's drained *once* for any given connection. So if this method
* returns false, a subsequent call while using the same websocket connection will simply block until we either
* get a new message or hit the timeout.
*
* Concerning the requested batch size, it's worth noting that this is simply an upper bound. This method will
* not wait extra time until the batch has "filled up". Instead, it will wait for a single message, and then
* take any extra messages that are also available up until you've hit your batch size.
*/
@SuppressWarnings("DuplicateThrows")
public boolean readMessage(long timeout, MessageReceivedCallback callback)
public boolean readMessageBatch(long timeout, int batchSize, MessageReceivedCallback callback)
throws TimeoutException, WebSocketUnavailableException, IOException
{
while (true) {
WebSocketRequestMessage request = getWebSocket().readRequest(timeout);
WebSocketResponseMessage response = createWebSocketResponse(request);
List<EnvelopeResponse> responses = new ArrayList<>();
boolean hitEndOfQueue = false;
AtomicBoolean successfullyProcessed = new AtomicBoolean(false);
Optional<EnvelopeResponse> firstEnvelope = waitForSingleMessage(timeout);
try {
if (isSignalServiceEnvelope(request)) {
Optional<String> timestampHeader = findHeader(request);
long timestamp = 0;
if (firstEnvelope.isPresent()) {
responses.add(firstEnvelope.get());
} else {
hitEndOfQueue = true;
}
if (timestampHeader.isPresent()) {
try {
timestamp = Long.parseLong(timestampHeader.get());
} catch (NumberFormatException e) {
Log.w(TAG, "Failed to parse " + SERVER_DELIVERED_TIMESTAMP_HEADER);
}
if (!hitEndOfQueue) {
for (int i = 1; i < batchSize; i++) {
Optional<WebSocketRequestMessage> request = getWebSocket().readRequestIfAvailable();
if (request.isPresent()) {
if (isSignalServiceEnvelope(request.get())) {
responses.add(requestToEnvelopeResponse(request.get()));
} else if (isSocketEmptyRequest(request.get())) {
hitEndOfQueue = true;
break;
}
SignalServiceProtos.Envelope envelope = SignalServiceProtos.Envelope.parseFrom(request.getBody().toByteArray());
successfullyProcessed.set(callback.onMessage(envelope, timestamp));
return true;
} else if (isSocketEmptyRequest(request)) {
return false;
}
} finally {
if (successfullyProcessed.get()) {
getWebSocket().sendResponse(response);
} else {
break;
}
}
}
if (responses.size() > 0) {
boolean successfullyProcessed = false;
try {
successfullyProcessed = callback.onMessageBatch(responses);
} finally {
if (successfullyProcessed) {
for (EnvelopeResponse response : responses) {
getWebSocket().sendResponse(createWebSocketResponse(response.getWebsocketRequest()));
}
}
}
}
return !hitEndOfQueue;
}
@SuppressWarnings("DuplicateThrows")
private Optional<EnvelopeResponse> waitForSingleMessage(long timeout)
throws TimeoutException, WebSocketUnavailableException, IOException
{
while (true) {
WebSocketRequestMessage request = getWebSocket().readRequest(timeout);
if (isSignalServiceEnvelope(request)) {
return Optional.of(requestToEnvelopeResponse(request));
} else if (isSocketEmptyRequest(request)) {
return Optional.empty();
}
}
}
private static EnvelopeResponse requestToEnvelopeResponse(WebSocketRequestMessage request)
throws InvalidProtocolBufferException
{
Optional<String> timestampHeader = findHeader(request);
long timestamp = 0;
if (timestampHeader.isPresent()) {
try {
timestamp = Long.parseLong(timestampHeader.get());
} catch (NumberFormatException e) {
Log.w(TAG, "Failed to parse " + SERVER_DELIVERED_TIMESTAMP_HEADER);
}
}
SignalServiceProtos.Envelope envelope = SignalServiceProtos.Envelope.parseFrom(request.getBody().toByteArray());
return new EnvelopeResponse(envelope, timestamp, request);
}
private static boolean isSignalServiceEnvelope(WebSocketRequestMessage message) {
@@ -323,6 +371,6 @@ public final class SignalWebSocket {
public interface MessageReceivedCallback {
/** True if you successfully processed the message, otherwise false. **/
boolean onMessage(SignalServiceProtos.Envelope envelope, long serverDeliveredTimestamp);
boolean onMessageBatch(List<EnvelopeResponse> envelopeResponses);
}
}

View File

@@ -0,0 +1,13 @@
package org.whispersystems.signalservice.api.messages
import org.whispersystems.signalservice.internal.push.SignalServiceProtos.Envelope
import org.whispersystems.signalservice.internal.websocket.WebSocketProtos.WebSocketRequestMessage
/**
* Represents an envelope off the wire, paired with the metadata needed to process it.
*/
class EnvelopeResponse(
val envelope: Envelope,
val serverDeliveredTimestamp: Long,
val websocketRequest: WebSocketRequestMessage
)

View File

@@ -193,6 +193,14 @@ public class WebSocketConnection extends WebSocketListener {
notifyAll();
}
public synchronized Optional<WebSocketRequestMessage> readRequestIfAvailable() {
if (incomingRequests.size() > 0) {
return Optional.of(incomingRequests.removeFirst());
} else {
return Optional.empty();
}
}
public synchronized WebSocketRequestMessage readRequest(long timeoutMillis)
throws TimeoutException, IOException
{