From bca220594535dc3a953be86dba8e000459706e73 Mon Sep 17 00:00:00 2001 From: Greyson Parrelli Date: Mon, 30 Aug 2021 15:07:03 -0400 Subject: [PATCH] Add measurements, improve MSL insert. --- .../securesms/AppCapabilities.java | 1 - .../storage/TextSecureSessionStore.java | 5 + .../database/MessageSendLogDatabase.kt | 15 +-- .../securesms/database/SessionDatabase.java | 43 ++++--- .../securesms/jobs/PushGroupSendJob.java | 7 +- .../securesms/messages/GroupSendUtil.java | 83 ++++++++++++- .../securesms/util/SignalLocalMetrics.java | 54 +++++++-- .../thoughtcrime/securesms/util/SqlUtil.java | 102 ++++++++++++++++ .../securesms/util/SqlUtilTest.java | 113 ++++++++++++++++++ .../api/SignalServiceMessageSender.java | 59 ++++++++- 10 files changed, 430 insertions(+), 52 deletions(-) diff --git a/app/src/main/java/org/thoughtcrime/securesms/AppCapabilities.java b/app/src/main/java/org/thoughtcrime/securesms/AppCapabilities.java index caf1587683..c23b855cf3 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/AppCapabilities.java +++ b/app/src/main/java/org/thoughtcrime/securesms/AppCapabilities.java @@ -1,6 +1,5 @@ package org.thoughtcrime.securesms; -import org.thoughtcrime.securesms.util.FeatureFlags; import org.whispersystems.signalservice.api.account.AccountAttributes; public final class AppCapabilities { diff --git a/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/TextSecureSessionStore.java b/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/TextSecureSessionStore.java index 836add58ef..497f9ff2a7 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/TextSecureSessionStore.java +++ b/app/src/main/java/org/thoughtcrime/securesms/crypto/storage/TextSecureSessionStore.java @@ -17,6 +17,7 @@ import org.whispersystems.libsignal.state.SessionRecord; import org.whispersystems.signalservice.api.SignalServiceSessionStore; import java.util.List; +import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; @@ -57,6 +58,10 @@ public class TextSecureSessionStore implements SignalServiceSessionStore { throw new NoSessionException(message); } + if (sessionRecords.stream().anyMatch(Objects::isNull)) { + throw new NoSessionException("Failed to find at least one session."); + } + return sessionRecords; } } diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/MessageSendLogDatabase.kt b/app/src/main/java/org/thoughtcrime/securesms/database/MessageSendLogDatabase.kt index 6a26b7f10f..2399eaf5d5 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/MessageSendLogDatabase.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/MessageSendLogDatabase.kt @@ -231,30 +231,31 @@ class MessageSendLogDatabase constructor(context: Context?, databaseHelper: SQLC val payloadId: Long = db.insert(PayloadTable.TABLE_NAME, null, payloadValues) + val recipientValues: MutableList = mutableListOf() recipients.forEach { recipientDevice -> recipientDevice.devices.forEach { device -> - val recipientValues = ContentValues().apply { + recipientValues += ContentValues().apply { put(RecipientTable.PAYLOAD_ID, payloadId) put(RecipientTable.RECIPIENT_ID, recipientDevice.recipientId.serialize()) put(RecipientTable.DEVICE, device) } - - db.insert(RecipientTable.TABLE_NAME, null, recipientValues) } } + SqlUtil.buildBulkInsert(RecipientTable.TABLE_NAME, arrayOf(RecipientTable.PAYLOAD_ID, RecipientTable.RECIPIENT_ID, RecipientTable.DEVICE), recipientValues) + .forEach { query -> db.execSQL(query.where, query.whereArgs) } + val messageValues: MutableList = mutableListOf() messageIds.forEach { messageId -> - val messageValues = ContentValues().apply { + messageValues += ContentValues().apply { put(MessageTable.PAYLOAD_ID, payloadId) put(MessageTable.MESSAGE_ID, messageId.id) put(MessageTable.IS_MMS, if (messageId.mms) 1 else 0) } - - db.insert(MessageTable.TABLE_NAME, null, messageValues) } + SqlUtil.buildBulkInsert(MessageTable.TABLE_NAME, arrayOf(MessageTable.PAYLOAD_ID, MessageTable.MESSAGE_ID, MessageTable.IS_MMS), messageValues) + .forEach { query -> db.execSQL(query.where, query.whereArgs) } db.setTransactionSuccessful() - return payloadId } finally { db.endTransaction() diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/SessionDatabase.java b/app/src/main/java/org/thoughtcrime/securesms/database/SessionDatabase.java index 27bb999a37..6d9f530d71 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/SessionDatabase.java +++ b/app/src/main/java/org/thoughtcrime/securesms/database/SessionDatabase.java @@ -20,6 +20,8 @@ import org.whispersystems.signalservice.api.push.SignalServiceAddress; import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.List; @@ -81,33 +83,36 @@ public class SessionDatabase extends Database { } public @NonNull List load(@NonNull List addresses) { - SQLiteDatabase database = databaseHelper.getSignalReadableDatabase(); - List sessions = new ArrayList<>(addresses.size()); + SQLiteDatabase database = databaseHelper.getSignalReadableDatabase(); + String query = ADDRESS + " = ? AND " + DEVICE + " = ?"; + List args = new ArrayList<>(addresses.size()); - database.beginTransaction(); - try { - String[] projection = new String[] { RECORD }; - String query = ADDRESS + " = ? AND " + DEVICE + " = ?"; + HashMap sessions = new LinkedHashMap<>(addresses.size()); - for (SignalProtocolAddress address : addresses) { - String[] args = SqlUtil.buildArgs(address.getName(), address.getDeviceId()); + for (SignalProtocolAddress address : addresses) { + args.add(SqlUtil.buildArgs(address.getName(), address.getDeviceId())); + sessions.put(address, null); + } - try (Cursor cursor = database.query(TABLE_NAME, projection, query, args, null, null, null)) { - if (cursor.moveToFirst()) { - try { - sessions.add(new SessionRecord(cursor.getBlob(cursor.getColumnIndexOrThrow(RECORD)))); - } catch (IOException e) { - Log.w(TAG, e); - } + String[] projection = new String[] { ADDRESS, DEVICE, RECORD }; + + for (SqlUtil.Query combinedQuery : SqlUtil.buildCustomCollectionQuery(query, args)) { + try (Cursor cursor = database.query(TABLE_NAME, projection, combinedQuery.getWhere(), combinedQuery.getWhereArgs(), null, null, null)) { + while (cursor.moveToNext()) { + String address = CursorUtil.requireString(cursor, ADDRESS); + int device = CursorUtil.requireInt(cursor, DEVICE); + + try { + SessionRecord record = new SessionRecord(cursor.getBlob(cursor.getColumnIndexOrThrow(RECORD))); + sessions.put(new SignalProtocolAddress(address, device), record); + } catch (IOException e) { + Log.w(TAG, e); } } } - database.setTransactionSuccessful(); - } finally { - database.endTransaction(); } - return sessions; + return new ArrayList<>(sessions.values()); } public @NonNull List getAllFor(@NonNull String addressName) { diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/PushGroupSendJob.java b/app/src/main/java/org/thoughtcrime/securesms/jobs/PushGroupSendJob.java index 1a848ed37a..528f062050 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/PushGroupSendJob.java +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/PushGroupSendJob.java @@ -163,7 +163,7 @@ public final class PushGroupSendJob extends PushSendJob { return; } - Recipient groupRecipient = message.getRecipient().fresh(); + Recipient groupRecipient = message.getRecipient().resolve(); if (!groupRecipient.isPushGroup()) { throw new MmsException("Message recipient isn't a group!"); @@ -188,8 +188,7 @@ public final class PushGroupSendJob extends PushSendJob { RecipientAccessList accessList = new RecipientAccessList(target); - List results = deliver(message, groupRecipient, target); - SignalLocalMetrics.GroupMessageSend.onNetworkFinished(messageId); + List results = deliver(message, groupRecipient, target); Log.i(TAG, JobLogger.format(this, "Finished send.")); List networkFailures = Stream.of(results).filter(SendMessageResult::isNetworkFailure).map(result -> new NetworkFailure(accessList.requireIdByAddress(result.getAddress()))).toList(); @@ -315,7 +314,6 @@ public final class PushGroupSendJob extends PushSendJob { .withExpiration(groupRecipient.getExpiresInSeconds()) .asGroupMessage(group) .build(); - SignalLocalMetrics.GroupMessageSend.onNetworkStarted(messageId); return GroupSendUtil.sendResendableDataMessage(context, groupRecipient.requireGroupId().requireV2(), destinations, isRecipientUpdate, ContentHint.IMPLICIT, new MessageId(messageId, true), groupDataMessage); } else { throw new UndeliverableMessageException("Messages can no longer be sent to V1 groups!"); @@ -347,7 +345,6 @@ public final class PushGroupSendJob extends PushSendJob { Log.i(TAG, JobLogger.format(this, "Beginning message send.")); - SignalLocalMetrics.GroupMessageSend.onNetworkStarted(messageId); return GroupSendUtil.sendResendableDataMessage(context, groupRecipient.getGroupId().transform(GroupId::requireV2).orNull(), destinations, diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/GroupSendUtil.java b/app/src/main/java/org/thoughtcrime/securesms/messages/GroupSendUtil.java index 504e1babc2..9c8817ea19 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/GroupSendUtil.java +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/GroupSendUtil.java @@ -21,6 +21,7 @@ import org.thoughtcrime.securesms.recipients.RecipientId; import org.thoughtcrime.securesms.recipients.RecipientUtil; import org.thoughtcrime.securesms.util.FeatureFlags; import org.thoughtcrime.securesms.util.RecipientAccessList; +import org.thoughtcrime.securesms.util.SignalLocalMetrics; import org.thoughtcrime.securesms.util.TextSecurePreferences; import org.whispersystems.libsignal.InvalidKeyException; import org.whispersystems.libsignal.InvalidRegistrationIdException; @@ -28,6 +29,8 @@ import org.whispersystems.libsignal.NoSessionException; import org.whispersystems.libsignal.util.guava.Optional; import org.whispersystems.signalservice.api.CancelationException; import org.whispersystems.signalservice.api.SignalServiceMessageSender; +import org.whispersystems.signalservice.api.SignalServiceMessageSender.LegacyGroupEvents; +import org.whispersystems.signalservice.api.SignalServiceMessageSender.SenderKeyGroupEvents; import org.whispersystems.signalservice.api.crypto.ContentHint; import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess; import org.whispersystems.signalservice.api.crypto.UnidentifiedAccessPair; @@ -83,7 +86,7 @@ public final class GroupSendUtil { @NonNull SignalServiceDataMessage message) throws IOException, UntrustedIdentityException { - return sendMessage(context, groupId, allTargets, isRecipientUpdate, DataSendOperation.resendable(message, contentHint, messageId), null); + return sendMessage(context, groupId, messageId, allTargets, isRecipientUpdate, DataSendOperation.resendable(message, contentHint, messageId), null); } /** @@ -104,7 +107,7 @@ public final class GroupSendUtil { @NonNull SignalServiceDataMessage message) throws IOException, UntrustedIdentityException { - return sendMessage(context, groupId, allTargets, isRecipientUpdate, DataSendOperation.unresendable(message, contentHint), null); + return sendMessage(context, groupId, null, allTargets, isRecipientUpdate, DataSendOperation.unresendable(message, contentHint), null); } /** @@ -121,7 +124,7 @@ public final class GroupSendUtil { @Nullable CancelationSignal cancelationSignal) throws IOException, UntrustedIdentityException { - return sendMessage(context, groupId, allTargets, false, new TypingSendOperation(message), cancelationSignal); + return sendMessage(context, groupId, null, allTargets, false, new TypingSendOperation(message), cancelationSignal); } /** @@ -137,7 +140,7 @@ public final class GroupSendUtil { @NonNull SignalServiceCallMessage message) throws IOException, UntrustedIdentityException { - return sendMessage(context, groupId, allTargets, false, new CallSendOperation(message), null); + return sendMessage(context, groupId, null, allTargets, false, new CallSendOperation(message), null); } /** @@ -150,6 +153,7 @@ public final class GroupSendUtil { @WorkerThread private static List sendMessage(@NonNull Context context, @Nullable GroupId.V2 groupId, + @Nullable MessageId relatedMessageId, @NonNull List allTargets, boolean isRecipientUpdate, @NonNull SendOperation sendOperation, @@ -205,6 +209,10 @@ public final class GroupSendUtil { senderKeyTargets.clear(); } + if (relatedMessageId != null) { + SignalLocalMetrics.GroupMessageSend.onSenderKeyStarted(relatedMessageId.getId()); + } + List allResults = new ArrayList<>(allTargets.size()); SignalServiceMessageSender messageSender = ApplicationDependencies.getSignalServiceMessageSender(); @@ -231,6 +239,10 @@ public final class GroupSendUtil { if (sendOperation.shouldIncludeInMessageLog()) { DatabaseFactory.getMessageLogDatabase(context).insertIfPossible(sendOperation.getSentTimestamp(), senderKeyTargets, results, sendOperation.getContentHint(), sendOperation.getRelatedMessageId()); } + + if (relatedMessageId != null) { + SignalLocalMetrics.GroupMessageSend.onSenderKeyMslInserted(relatedMessageId.getId()); + } } catch (InvalidUnidentifiedAccessHeaderException e) { Log.w(TAG, "Someone had a bad UD header. Falling back to legacy sends.", e); legacyTargets.addAll(senderKeyTargets); @@ -244,6 +256,12 @@ public final class GroupSendUtil { Log.w(TAG, "Invalid registrationId. Falling back to legacy sends.", e); legacyTargets.addAll(senderKeyTargets); } + } else if (relatedMessageId != null) { + SignalLocalMetrics.GroupMessageSend.onSenderKeyShared(relatedMessageId.getId()); + SignalLocalMetrics.GroupMessageSend.onSenderKeyEncrypted(relatedMessageId.getId()); + SignalLocalMetrics.GroupMessageSend.onSenderKeyMessageSent(relatedMessageId.getId()); + SignalLocalMetrics.GroupMessageSend.onSenderKeySyncSent(relatedMessageId.getId()); + SignalLocalMetrics.GroupMessageSend.onSenderKeyMslInserted(relatedMessageId.getId()); } if (cancelationSignal != null && cancelationSignal.isCanceled()) { @@ -285,6 +303,9 @@ public final class GroupSendUtil { int successCount = (int) results.stream().filter(SendMessageResult::isSuccess).count(); Log.d(TAG, "Successfully sent using 1:1 to " + successCount + "/" + targets.size() + " legacy targets."); + } else if (relatedMessageId != null) { + SignalLocalMetrics.GroupMessageSend.onLegacyMessageSent(relatedMessageId.getId()); + SignalLocalMetrics.GroupMessageSend.onLegacySyncFinished(relatedMessageId.getId()); } if (unregisteredTargets.size() > 0) { @@ -361,7 +382,8 @@ public final class GroupSendUtil { boolean isRecipientUpdate) throws NoSessionException, UntrustedIdentityException, InvalidKeyException, IOException, InvalidRegistrationIdException { - return messageSender.sendGroupDataMessage(distributionId, targets, access, isRecipientUpdate, contentHint, message); + SenderKeyGroupEvents listener = relatedMessageId != null ? new SenderKeyMetricEventListener(relatedMessageId.getId()) : SenderKeyGroupEvents.EMPTY; + return messageSender.sendGroupDataMessage(distributionId, targets, access, isRecipientUpdate, contentHint, message, listener); } @Override @@ -373,7 +395,8 @@ public final class GroupSendUtil { @Nullable CancelationSignal cancelationSignal) throws IOException, UntrustedIdentityException { - return messageSender.sendDataMessage(targets, access, isRecipientUpdate, contentHint, message, partialListener, cancelationSignal); + LegacyGroupEvents listener = relatedMessageId != null ? new LegacyMetricEventListener(relatedMessageId.getId()) : LegacyGroupEvents.EMPTY; + return messageSender.sendDataMessage(targets, access, isRecipientUpdate, contentHint, message, listener, partialListener, cancelationSignal); } @Override @@ -507,6 +530,54 @@ public final class GroupSendUtil { } } + private static final class SenderKeyMetricEventListener implements SenderKeyGroupEvents { + + private final long messageId; + + private SenderKeyMetricEventListener(long messageId) { + this.messageId = messageId; + } + + @Override + public void onSenderKeyShared() { + SignalLocalMetrics.GroupMessageSend.onSenderKeyShared(messageId); + } + + @Override + public void onMessageEncrypted() { + SignalLocalMetrics.GroupMessageSend.onSenderKeyEncrypted(messageId); + } + + @Override + public void onMessageSent() { + SignalLocalMetrics.GroupMessageSend.onSenderKeyMessageSent(messageId); + } + + @Override + public void onSyncMessageSent() { + SignalLocalMetrics.GroupMessageSend.onSenderKeySyncSent(messageId); + } + } + + private static final class LegacyMetricEventListener implements LegacyGroupEvents { + + private final long messageId; + + private LegacyMetricEventListener(long messageId) { + this.messageId = messageId; + } + + @Override + public void onMessageSent() { + SignalLocalMetrics.GroupMessageSend.onLegacyMessageSent(messageId); + } + + @Override + public void onSyncMessageSent() { + SignalLocalMetrics.GroupMessageSend.onLegacySyncFinished(messageId); + } + } + /** * Little utility wrapper that lets us get the various different slices of recipient models that we need for different methods. */ diff --git a/app/src/main/java/org/thoughtcrime/securesms/util/SignalLocalMetrics.java b/app/src/main/java/org/thoughtcrime/securesms/util/SignalLocalMetrics.java index 89c6bf8085..fa053daa26 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/util/SignalLocalMetrics.java +++ b/app/src/main/java/org/thoughtcrime/securesms/util/SignalLocalMetrics.java @@ -178,12 +178,18 @@ public final class SignalLocalMetrics { public static final class GroupMessageSend { private static final String NAME = "group-message-send"; - private static final String SPLIT_DB_INSERT = "db-insert"; - private static final String SPLIT_JOB_ENQUEUE = "job-enqueue"; - private static final String SPLIT_JOB_PRE_NETWORK = "job-pre-network"; - private static final String SPLIT_NETWORK = "network"; - private static final String SPLIT_JOB_POST_NETWORK = "job-post-network"; - private static final String SPLIT_UI_UPDATE = "ui-update"; + private static final String SPLIT_DB_INSERT = "db-insert"; + private static final String SPLIT_JOB_ENQUEUE = "job-enqueue"; + private static final String SPLIT_JOB_PRE_NETWORK = "job-pre-network"; + private static final String SPLIT_SENDER_KEY_SHARED = "sk-shared"; + private static final String SPLIT_ENCRYPTION = "encryption"; + private static final String SPLIT_NETWORK_SENDER_KEY = "network-sk"; + private static final String SPLIT_NETWORK_SENDER_KEY_SYNC = "network-sk-sync"; + private static final String SPLIT_MSL_SENDER_KEY = "msl-sk"; + private static final String SPLIT_NETWORK_LEGACY = "network-legacy"; + private static final String SPLIT_NETWORK_LEGACY_SYNC = "network-legacy-sync"; + private static final String SPLIT_JOB_POST_NETWORK = "job-post-network"; + private static final String SPLIT_UI_UPDATE = "ui-update"; private static final Map ID_MAP = new HashMap<>(); @@ -205,14 +211,44 @@ public final class SignalLocalMetrics { LocalMetrics.getInstance().split(requireId(messageId), SPLIT_JOB_ENQUEUE); } - public static void onNetworkStarted(long messageId) { + public static void onSenderKeyStarted(long messageId) { if (!ID_MAP.containsKey(messageId)) return; LocalMetrics.getInstance().split(requireId(messageId), SPLIT_JOB_PRE_NETWORK); } - public static void onNetworkFinished(long messageId) { + public static void onSenderKeyShared(long messageId) { if (!ID_MAP.containsKey(messageId)) return; - LocalMetrics.getInstance().split(requireId(messageId), SPLIT_NETWORK); + LocalMetrics.getInstance().split(requireId(messageId), SPLIT_SENDER_KEY_SHARED); + } + + public static void onSenderKeyEncrypted(long messageId) { + if (!ID_MAP.containsKey(messageId)) return; + LocalMetrics.getInstance().split(requireId(messageId), SPLIT_ENCRYPTION); + } + + public static void onSenderKeyMessageSent(long messageId) { + if (!ID_MAP.containsKey(messageId)) return; + LocalMetrics.getInstance().split(requireId(messageId), SPLIT_NETWORK_SENDER_KEY); + } + + public static void onSenderKeySyncSent(long messageId) { + if (!ID_MAP.containsKey(messageId)) return; + LocalMetrics.getInstance().split(requireId(messageId), SPLIT_NETWORK_SENDER_KEY_SYNC); + } + + public static void onSenderKeyMslInserted(long messageId) { + if (!ID_MAP.containsKey(messageId)) return; + LocalMetrics.getInstance().split(requireId(messageId), SPLIT_MSL_SENDER_KEY); + } + + public static void onLegacyMessageSent(long messageId) { + if (!ID_MAP.containsKey(messageId)) return; + LocalMetrics.getInstance().split(requireId(messageId), SPLIT_NETWORK_LEGACY); + } + + public static void onLegacySyncFinished(long messageId) { + if (!ID_MAP.containsKey(messageId)) return; + LocalMetrics.getInstance().split(requireId(messageId), SPLIT_NETWORK_LEGACY_SYNC); } public static void onJobFinished(long messageId) { diff --git a/app/src/main/java/org/thoughtcrime/securesms/util/SqlUtil.java b/app/src/main/java/org/thoughtcrime/securesms/util/SqlUtil.java index 6d01caaaa9..f08d0f0ae3 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/util/SqlUtil.java +++ b/app/src/main/java/org/thoughtcrime/securesms/util/SqlUtil.java @@ -4,6 +4,7 @@ import android.content.ContentValues; import android.database.Cursor; import androidx.annotation.NonNull; +import androidx.annotation.VisibleForTesting; import com.annimon.stream.Stream; @@ -19,9 +20,13 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; public final class SqlUtil { + /** The maximum number of arguments (i.e. question marks) allowed in a SQL statement. */ + private static final int MAX_QUERY_ARGS = 999; + private SqlUtil() {} public static boolean tableExists(@NonNull SQLiteDatabase db, @NonNull String table) { @@ -155,6 +160,41 @@ public final class SqlUtil { return new Query(column + " IN (" + query.toString() + ")", buildArgs(args)); } + public static @NonNull List buildCustomCollectionQuery(@NonNull String query, @NonNull List argList) { + return buildCustomCollectionQuery(query, argList, MAX_QUERY_ARGS); + } + + @VisibleForTesting + static @NonNull List buildCustomCollectionQuery(@NonNull String query, @NonNull List argList, int maxQueryArgs) { + int batchSize = maxQueryArgs / argList.get(0).length; + + return Util.chunk(argList, batchSize) + .stream() + .map(argBatch -> buildSingleCustomCollectionQuery(query, argBatch)) + .collect(Collectors.toList()); + } + + private static @NonNull Query buildSingleCustomCollectionQuery(@NonNull String query, @NonNull List argList) { + StringBuilder outputQuery = new StringBuilder(); + String[] outputArgs = new String[argList.get(0).length * argList.size()]; + int argPosition = 0; + + for (int i = 0, len = argList.size(); i < len; i++) { + outputQuery.append("(").append(query).append(")"); + if (i < len - 1) { + outputQuery.append(" OR "); + } + + String[] args = argList.get(i); + for (String arg : args) { + outputArgs[argPosition] = arg; + argPosition++; + } + } + + return new Query(outputQuery.toString(), outputArgs); + } + public static @NonNull Query buildQuery(@NonNull String where, @NonNull Object... args) { return new SqlUtil.Query(where, SqlUtil.buildArgs(args)); } @@ -168,6 +208,68 @@ public final class SqlUtil { return output; } + public static List buildBulkInsert(@NonNull String tableName, @NonNull String[] columns, List contentValues) { + return buildBulkInsert(tableName, columns, contentValues, MAX_QUERY_ARGS); + } + + @VisibleForTesting + static List buildBulkInsert(@NonNull String tableName, @NonNull String[] columns, List contentValues, int maxQueryArgs) { + int batchSize = maxQueryArgs / columns.length; + + return Util.chunk(contentValues, batchSize) + .stream() + .map(batch -> buildSingleBulkInsert(tableName, columns, batch)) + .collect(Collectors.toList()); + } + + private static Query buildSingleBulkInsert(@NonNull String tableName, @NonNull String[] columns, List contentValues) { + StringBuilder builder = new StringBuilder(); + builder.append("INSERT INTO ").append(tableName).append(" ("); + + for (int i = 0; i < columns.length; i++) { + builder.append(columns[i]); + if (i < columns.length - 1) { + builder.append(", "); + } + } + + builder.append(") VALUES "); + + StringBuilder placeholder = new StringBuilder(); + placeholder.append("("); + + for (int i = 0; i < columns.length; i++) { + placeholder.append("?"); + if (i < columns.length - 1) { + placeholder.append(", "); + } + } + + placeholder.append(")"); + + + for (int i = 0, len = contentValues.size(); i < len; i++) { + builder.append(placeholder); + if (i < len - 1) { + builder.append(", "); + } + } + + String query = builder.toString(); + String[] args = new String[columns.length * contentValues.size()]; + + int i = 0; + for (ContentValues values : contentValues) { + for (String column : columns) { + Object value = values.get(column); + args[i] = value != null ? values.get(column).toString() : "null"; + i++; + } + } + + return new Query(query, args); + } + public static class Query { private final String where; private final String[] whereArgs; diff --git a/app/src/test/java/org/thoughtcrime/securesms/util/SqlUtilTest.java b/app/src/test/java/org/thoughtcrime/securesms/util/SqlUtilTest.java index f6dcd577a8..0312de1a93 100644 --- a/app/src/test/java/org/thoughtcrime/securesms/util/SqlUtilTest.java +++ b/app/src/test/java/org/thoughtcrime/securesms/util/SqlUtilTest.java @@ -10,6 +10,7 @@ import org.robolectric.annotation.Config; import org.thoughtcrime.securesms.recipients.Recipient; import org.thoughtcrime.securesms.recipients.RecipientId; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -126,6 +127,48 @@ public final class SqlUtilTest { SqlUtil.buildCollectionQuery("a", Collections.emptyList()); } + @Test + public void buildCustomCollectionQuery_single_singleBatch() { + List args = new ArrayList<>(); + args.add(SqlUtil.buildArgs(1, 2)); + + List queries = SqlUtil.buildCustomCollectionQuery("a = ? AND b = ?", args); + + assertEquals(1, queries.size()); + assertEquals("(a = ? AND b = ?)", queries.get(0).getWhere()); + assertArrayEquals(new String[] { "1", "2" }, queries.get(0).getWhereArgs()); + } + + @Test + public void buildCustomCollectionQuery_multiple_singleBatch() { + List args = new ArrayList<>(); + args.add(SqlUtil.buildArgs(1, 2)); + args.add(SqlUtil.buildArgs(3, 4)); + args.add(SqlUtil.buildArgs(5, 6)); + + List queries = SqlUtil.buildCustomCollectionQuery("a = ? AND b = ?", args); + + assertEquals(1, queries.size()); + assertEquals("(a = ? AND b = ?) OR (a = ? AND b = ?) OR (a = ? AND b = ?)", queries.get(0).getWhere()); + assertArrayEquals(new String[] { "1", "2", "3", "4", "5", "6" }, queries.get(0).getWhereArgs()); + } + + @Test + public void buildCustomCollectionQuery_twoBatches() { + List args = new ArrayList<>(); + args.add(SqlUtil.buildArgs(1, 2)); + args.add(SqlUtil.buildArgs(3, 4)); + args.add(SqlUtil.buildArgs(5, 6)); + + List queries = SqlUtil.buildCustomCollectionQuery("a = ? AND b = ?", args, 4); + + assertEquals(2, queries.size()); + assertEquals("(a = ? AND b = ?) OR (a = ? AND b = ?)", queries.get(0).getWhere()); + assertArrayEquals(new String[] { "1", "2", "3", "4" }, queries.get(0).getWhereArgs()); + assertEquals("(a = ? AND b = ?)", queries.get(1).getWhere()); + assertArrayEquals(new String[] { "5", "6" }, queries.get(1).getWhereArgs()); + } + @Test public void splitStatements_singleStatement() { List result = SqlUtil.splitStatements("SELECT * FROM foo;\n"); @@ -143,4 +186,74 @@ public final class SqlUtilTest { List result = SqlUtil.splitStatements("SELECT * FROM foo;\n\nSELECT * FROM bar;\n"); assertEquals(Arrays.asList("SELECT * FROM foo", "SELECT * FROM bar"), result); } + + @Test + public void buildBulkInsert_single_singleBatch() { + List contentValues = new ArrayList<>(); + + ContentValues cv1 = new ContentValues(); + cv1.put("a", 1); + cv1.put("b", 2); + + contentValues.add(cv1); + + List output = SqlUtil.buildBulkInsert("mytable", new String[] { "a", "b"}, contentValues); + + assertEquals(1, output.size()); + assertEquals("INSERT INTO mytable (a, b) VALUES (?, ?)", output.get(0).getWhere()); + assertArrayEquals(new String[] { "1", "2" }, output.get(0).getWhereArgs()); + } + + @Test + public void buildBulkInsert_multiple_singleBatch() { + List contentValues = new ArrayList<>(); + + ContentValues cv1 = new ContentValues(); + cv1.put("a", 1); + cv1.put("b", 2); + + ContentValues cv2 = new ContentValues(); + cv2.put("a", 3); + cv2.put("b", 4); + + contentValues.add(cv1); + contentValues.add(cv2); + + List output = SqlUtil.buildBulkInsert("mytable", new String[] { "a", "b"}, contentValues); + + assertEquals(1, output.size()); + assertEquals("INSERT INTO mytable (a, b) VALUES (?, ?), (?, ?)", output.get(0).getWhere()); + assertArrayEquals(new String[] { "1", "2", "3", "4" }, output.get(0).getWhereArgs()); + } + + @Test + public void buildBulkInsert_twoBatches() { + List contentValues = new ArrayList<>(); + + ContentValues cv1 = new ContentValues(); + cv1.put("a", 1); + cv1.put("b", 2); + + ContentValues cv2 = new ContentValues(); + cv2.put("a", 3); + cv2.put("b", 4); + + ContentValues cv3 = new ContentValues(); + cv3.put("a", 5); + cv3.put("b", 6); + + contentValues.add(cv1); + contentValues.add(cv2); + contentValues.add(cv3); + + List output = SqlUtil.buildBulkInsert("mytable", new String[] { "a", "b"}, contentValues, 4); + + assertEquals(2, output.size()); + + assertEquals("INSERT INTO mytable (a, b) VALUES (?, ?), (?, ?)", output.get(0).getWhere()); + assertArrayEquals(new String[] { "1", "2", "3", "4" }, output.get(0).getWhereArgs()); + + assertEquals("INSERT INTO mytable (a, b) VALUES (?, ?)", output.get(1).getWhere()); + assertArrayEquals(new String[] { "5", "6" }, output.get(1).getWhereArgs()); + } } diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java index 4737eba915..4d3fcd7fa6 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java @@ -253,7 +253,7 @@ public class SignalServiceMessageSender { throws IOException, UntrustedIdentityException, InvalidKeyException, NoSessionException, InvalidRegistrationIdException { Content content = createTypingContent(message); - sendGroupMessage(distributionId, recipients, unidentifiedAccess, message.getTimestamp(), content, ContentHint.IMPLICIT, message.getGroupId().orNull(), true); + sendGroupMessage(distributionId, recipients, unidentifiedAccess, message.getTimestamp(), content, ContentHint.IMPLICIT, message.getGroupId().orNull(), true, SenderKeyGroupEvents.EMPTY); } @@ -293,7 +293,7 @@ public class SignalServiceMessageSender { throws IOException, UntrustedIdentityException, InvalidKeyException, NoSessionException, InvalidRegistrationIdException { Content content = createCallContent(message); - return sendGroupMessage(distributionId, recipients, unidentifiedAccess, message.getTimestamp().get(), content, ContentHint.IMPLICIT, message.getGroupId().get(), false); + return sendGroupMessage(distributionId, recipients, unidentifiedAccess, message.getTimestamp().get(), content, ContentHint.IMPLICIT, message.getGroupId().get(), false, SenderKeyGroupEvents.EMPTY); } /** @@ -420,14 +420,17 @@ public class SignalServiceMessageSender { List unidentifiedAccess, boolean isRecipientUpdate, ContentHint contentHint, - SignalServiceDataMessage message) + SignalServiceDataMessage message, + SenderKeyGroupEvents sendEvents) throws IOException, UntrustedIdentityException, NoSessionException, InvalidKeyException, InvalidRegistrationIdException { Log.d(TAG, "[" + message.getTimestamp() + "] Sending a group data message to " + recipients.size() + " recipients."); Content content = createMessageContent(message); Optional groupId = message.getGroupId(); - List results = sendGroupMessage(distributionId, recipients, unidentifiedAccess, message.getTimestamp(), content, contentHint, groupId.orNull(), false); + List results = sendGroupMessage(distributionId, recipients, unidentifiedAccess, message.getTimestamp(), content, contentHint, groupId.orNull(), false, sendEvents); + + sendEvents.onMessageSent(); if (store.isMultiDevice()) { Content syncMessage = createMultiDeviceSentTranscriptContent(content, Optional.absent(), message.getTimestamp(), results, isRecipientUpdate); @@ -436,6 +439,8 @@ public class SignalServiceMessageSender { sendMessage(localAddress, Optional.absent(), message.getTimestamp(), syncMessageContent, false, null); } + sendEvents.onSyncMessageSent(); + return results; } @@ -450,6 +455,7 @@ public class SignalServiceMessageSender { boolean isRecipientUpdate, ContentHint contentHint, SignalServiceDataMessage message, + LegacyGroupEvents sendEvents, PartialSendCompleteListener partialListener, CancelationSignal cancelationSignal) throws IOException, UntrustedIdentityException @@ -462,6 +468,8 @@ public class SignalServiceMessageSender { List results = sendMessage(recipients, getTargetUnidentifiedAccess(unidentifiedAccess), timestamp, envelopeContent, false, partialListener, cancelationSignal); boolean needsSyncInResults = false; + sendEvents.onMessageSent(); + for (SendMessageResult result : results) { if (result.getSuccess() != null && result.getSuccess().isNeedsSync()) { needsSyncInResults = true; @@ -481,6 +489,8 @@ public class SignalServiceMessageSender { sendMessage(localAddress, Optional.absent(), timestamp, syncMessageContent, false, null); } + sendEvents.onSyncMessageSent(); + return results; } @@ -1673,7 +1683,8 @@ public class SignalServiceMessageSender { Content content, ContentHint contentHint, byte[] groupId, - boolean online) + boolean online, + SenderKeyGroupEvents sendEvents) throws IOException, UntrustedIdentityException, NoSessionException, InvalidKeyException, InvalidRegistrationIdException { if (recipients.isEmpty()) { @@ -1751,6 +1762,8 @@ public class SignalServiceMessageSender { } } + sendEvents.onSenderKeyShared(); + SignalServiceCipher cipher = new SignalServiceCipher(localAddress, store, sessionLock, null); SenderCertificate senderCertificate = unidentifiedAccess.get(0).getUnidentifiedCertificate(); @@ -1761,6 +1774,8 @@ public class SignalServiceMessageSender { throw new UntrustedIdentityException("Untrusted during group encrypt", e.getName(), e.getUntrustedIdentity()); } + sendEvents.onMessageEncrypted(); + byte[] joinedUnidentifiedAccess = new byte[16]; for (UnidentifiedAccess access : unidentifiedAccess) { joinedUnidentifiedAccess = ByteArrayUtil.xor(joinedUnidentifiedAccess, access.getUnidentifiedAccessKey()); @@ -2108,4 +2123,38 @@ public class SignalServiceMessageSender { void onMessageSent(); void onSyncMessageSent(); } + + public interface SenderKeyGroupEvents { + SenderKeyGroupEvents EMPTY = new SenderKeyGroupEvents() { + @Override + public void onSenderKeyShared() { } + + @Override + public void onMessageEncrypted() { } + + @Override + public void onMessageSent() { } + + @Override + public void onSyncMessageSent() { } + }; + + void onSenderKeyShared(); + void onMessageEncrypted(); + void onMessageSent(); + void onSyncMessageSent(); + } + + public interface LegacyGroupEvents { + LegacyGroupEvents EMPTY = new LegacyGroupEvents() { + @Override + public void onMessageSent() { } + + @Override + public void onSyncMessageSent() { } + }; + + void onMessageSent(); + void onSyncMessageSent(); + } }