Add measurements, improve MSL insert.

This commit is contained in:
Greyson Parrelli
2021-08-30 15:07:03 -04:00
parent 1241f4c0e9
commit bca2205945
10 changed files with 430 additions and 52 deletions

View File

@@ -1,6 +1,5 @@
package org.thoughtcrime.securesms;
import org.thoughtcrime.securesms.util.FeatureFlags;
import org.whispersystems.signalservice.api.account.AccountAttributes;
public final class AppCapabilities {

View File

@@ -17,6 +17,7 @@ import org.whispersystems.libsignal.state.SessionRecord;
import org.whispersystems.signalservice.api.SignalServiceSessionStore;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@@ -57,6 +58,10 @@ public class TextSecureSessionStore implements SignalServiceSessionStore {
throw new NoSessionException(message);
}
if (sessionRecords.stream().anyMatch(Objects::isNull)) {
throw new NoSessionException("Failed to find at least one session.");
}
return sessionRecords;
}
}

View File

@@ -231,30 +231,31 @@ class MessageSendLogDatabase constructor(context: Context?, databaseHelper: SQLC
val payloadId: Long = db.insert(PayloadTable.TABLE_NAME, null, payloadValues)
val recipientValues: MutableList<ContentValues> = mutableListOf()
recipients.forEach { recipientDevice ->
recipientDevice.devices.forEach { device ->
val recipientValues = ContentValues().apply {
recipientValues += ContentValues().apply {
put(RecipientTable.PAYLOAD_ID, payloadId)
put(RecipientTable.RECIPIENT_ID, recipientDevice.recipientId.serialize())
put(RecipientTable.DEVICE, device)
}
db.insert(RecipientTable.TABLE_NAME, null, recipientValues)
}
}
SqlUtil.buildBulkInsert(RecipientTable.TABLE_NAME, arrayOf(RecipientTable.PAYLOAD_ID, RecipientTable.RECIPIENT_ID, RecipientTable.DEVICE), recipientValues)
.forEach { query -> db.execSQL(query.where, query.whereArgs) }
val messageValues: MutableList<ContentValues> = mutableListOf()
messageIds.forEach { messageId ->
val messageValues = ContentValues().apply {
messageValues += ContentValues().apply {
put(MessageTable.PAYLOAD_ID, payloadId)
put(MessageTable.MESSAGE_ID, messageId.id)
put(MessageTable.IS_MMS, if (messageId.mms) 1 else 0)
}
db.insert(MessageTable.TABLE_NAME, null, messageValues)
}
SqlUtil.buildBulkInsert(MessageTable.TABLE_NAME, arrayOf(MessageTable.PAYLOAD_ID, MessageTable.MESSAGE_ID, MessageTable.IS_MMS), messageValues)
.forEach { query -> db.execSQL(query.where, query.whereArgs) }
db.setTransactionSuccessful()
return payloadId
} finally {
db.endTransaction()

View File

@@ -20,6 +20,8 @@ import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
@@ -82,32 +84,35 @@ public class SessionDatabase extends Database {
public @NonNull List<SessionRecord> load(@NonNull List<SignalProtocolAddress> addresses) {
SQLiteDatabase database = databaseHelper.getSignalReadableDatabase();
List<SessionRecord> sessions = new ArrayList<>(addresses.size());
database.beginTransaction();
try {
String[] projection = new String[] { RECORD };
String query = ADDRESS + " = ? AND " + DEVICE + " = ?";
List<String[]> args = new ArrayList<>(addresses.size());
HashMap<SignalProtocolAddress, SessionRecord> sessions = new LinkedHashMap<>(addresses.size());
for (SignalProtocolAddress address : addresses) {
String[] args = SqlUtil.buildArgs(address.getName(), address.getDeviceId());
args.add(SqlUtil.buildArgs(address.getName(), address.getDeviceId()));
sessions.put(address, null);
}
String[] projection = new String[] { ADDRESS, DEVICE, RECORD };
for (SqlUtil.Query combinedQuery : SqlUtil.buildCustomCollectionQuery(query, args)) {
try (Cursor cursor = database.query(TABLE_NAME, projection, combinedQuery.getWhere(), combinedQuery.getWhereArgs(), null, null, null)) {
while (cursor.moveToNext()) {
String address = CursorUtil.requireString(cursor, ADDRESS);
int device = CursorUtil.requireInt(cursor, DEVICE);
try (Cursor cursor = database.query(TABLE_NAME, projection, query, args, null, null, null)) {
if (cursor.moveToFirst()) {
try {
sessions.add(new SessionRecord(cursor.getBlob(cursor.getColumnIndexOrThrow(RECORD))));
SessionRecord record = new SessionRecord(cursor.getBlob(cursor.getColumnIndexOrThrow(RECORD)));
sessions.put(new SignalProtocolAddress(address, device), record);
} catch (IOException e) {
Log.w(TAG, e);
}
}
}
}
database.setTransactionSuccessful();
} finally {
database.endTransaction();
}
return sessions;
return new ArrayList<>(sessions.values());
}
public @NonNull List<SessionRow> getAllFor(@NonNull String addressName) {

View File

@@ -163,7 +163,7 @@ public final class PushGroupSendJob extends PushSendJob {
return;
}
Recipient groupRecipient = message.getRecipient().fresh();
Recipient groupRecipient = message.getRecipient().resolve();
if (!groupRecipient.isPushGroup()) {
throw new MmsException("Message recipient isn't a group!");
@@ -189,7 +189,6 @@ public final class PushGroupSendJob extends PushSendJob {
RecipientAccessList accessList = new RecipientAccessList(target);
List<SendMessageResult> results = deliver(message, groupRecipient, target);
SignalLocalMetrics.GroupMessageSend.onNetworkFinished(messageId);
Log.i(TAG, JobLogger.format(this, "Finished send."));
List<NetworkFailure> networkFailures = Stream.of(results).filter(SendMessageResult::isNetworkFailure).map(result -> new NetworkFailure(accessList.requireIdByAddress(result.getAddress()))).toList();
@@ -315,7 +314,6 @@ public final class PushGroupSendJob extends PushSendJob {
.withExpiration(groupRecipient.getExpiresInSeconds())
.asGroupMessage(group)
.build();
SignalLocalMetrics.GroupMessageSend.onNetworkStarted(messageId);
return GroupSendUtil.sendResendableDataMessage(context, groupRecipient.requireGroupId().requireV2(), destinations, isRecipientUpdate, ContentHint.IMPLICIT, new MessageId(messageId, true), groupDataMessage);
} else {
throw new UndeliverableMessageException("Messages can no longer be sent to V1 groups!");
@@ -347,7 +345,6 @@ public final class PushGroupSendJob extends PushSendJob {
Log.i(TAG, JobLogger.format(this, "Beginning message send."));
SignalLocalMetrics.GroupMessageSend.onNetworkStarted(messageId);
return GroupSendUtil.sendResendableDataMessage(context,
groupRecipient.getGroupId().transform(GroupId::requireV2).orNull(),
destinations,

View File

@@ -21,6 +21,7 @@ import org.thoughtcrime.securesms.recipients.RecipientId;
import org.thoughtcrime.securesms.recipients.RecipientUtil;
import org.thoughtcrime.securesms.util.FeatureFlags;
import org.thoughtcrime.securesms.util.RecipientAccessList;
import org.thoughtcrime.securesms.util.SignalLocalMetrics;
import org.thoughtcrime.securesms.util.TextSecurePreferences;
import org.whispersystems.libsignal.InvalidKeyException;
import org.whispersystems.libsignal.InvalidRegistrationIdException;
@@ -28,6 +29,8 @@ import org.whispersystems.libsignal.NoSessionException;
import org.whispersystems.libsignal.util.guava.Optional;
import org.whispersystems.signalservice.api.CancelationException;
import org.whispersystems.signalservice.api.SignalServiceMessageSender;
import org.whispersystems.signalservice.api.SignalServiceMessageSender.LegacyGroupEvents;
import org.whispersystems.signalservice.api.SignalServiceMessageSender.SenderKeyGroupEvents;
import org.whispersystems.signalservice.api.crypto.ContentHint;
import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess;
import org.whispersystems.signalservice.api.crypto.UnidentifiedAccessPair;
@@ -83,7 +86,7 @@ public final class GroupSendUtil {
@NonNull SignalServiceDataMessage message)
throws IOException, UntrustedIdentityException
{
return sendMessage(context, groupId, allTargets, isRecipientUpdate, DataSendOperation.resendable(message, contentHint, messageId), null);
return sendMessage(context, groupId, messageId, allTargets, isRecipientUpdate, DataSendOperation.resendable(message, contentHint, messageId), null);
}
/**
@@ -104,7 +107,7 @@ public final class GroupSendUtil {
@NonNull SignalServiceDataMessage message)
throws IOException, UntrustedIdentityException
{
return sendMessage(context, groupId, allTargets, isRecipientUpdate, DataSendOperation.unresendable(message, contentHint), null);
return sendMessage(context, groupId, null, allTargets, isRecipientUpdate, DataSendOperation.unresendable(message, contentHint), null);
}
/**
@@ -121,7 +124,7 @@ public final class GroupSendUtil {
@Nullable CancelationSignal cancelationSignal)
throws IOException, UntrustedIdentityException
{
return sendMessage(context, groupId, allTargets, false, new TypingSendOperation(message), cancelationSignal);
return sendMessage(context, groupId, null, allTargets, false, new TypingSendOperation(message), cancelationSignal);
}
/**
@@ -137,7 +140,7 @@ public final class GroupSendUtil {
@NonNull SignalServiceCallMessage message)
throws IOException, UntrustedIdentityException
{
return sendMessage(context, groupId, allTargets, false, new CallSendOperation(message), null);
return sendMessage(context, groupId, null, allTargets, false, new CallSendOperation(message), null);
}
/**
@@ -150,6 +153,7 @@ public final class GroupSendUtil {
@WorkerThread
private static List<SendMessageResult> sendMessage(@NonNull Context context,
@Nullable GroupId.V2 groupId,
@Nullable MessageId relatedMessageId,
@NonNull List<Recipient> allTargets,
boolean isRecipientUpdate,
@NonNull SendOperation sendOperation,
@@ -205,6 +209,10 @@ public final class GroupSendUtil {
senderKeyTargets.clear();
}
if (relatedMessageId != null) {
SignalLocalMetrics.GroupMessageSend.onSenderKeyStarted(relatedMessageId.getId());
}
List<SendMessageResult> allResults = new ArrayList<>(allTargets.size());
SignalServiceMessageSender messageSender = ApplicationDependencies.getSignalServiceMessageSender();
@@ -231,6 +239,10 @@ public final class GroupSendUtil {
if (sendOperation.shouldIncludeInMessageLog()) {
DatabaseFactory.getMessageLogDatabase(context).insertIfPossible(sendOperation.getSentTimestamp(), senderKeyTargets, results, sendOperation.getContentHint(), sendOperation.getRelatedMessageId());
}
if (relatedMessageId != null) {
SignalLocalMetrics.GroupMessageSend.onSenderKeyMslInserted(relatedMessageId.getId());
}
} catch (InvalidUnidentifiedAccessHeaderException e) {
Log.w(TAG, "Someone had a bad UD header. Falling back to legacy sends.", e);
legacyTargets.addAll(senderKeyTargets);
@@ -244,6 +256,12 @@ public final class GroupSendUtil {
Log.w(TAG, "Invalid registrationId. Falling back to legacy sends.", e);
legacyTargets.addAll(senderKeyTargets);
}
} else if (relatedMessageId != null) {
SignalLocalMetrics.GroupMessageSend.onSenderKeyShared(relatedMessageId.getId());
SignalLocalMetrics.GroupMessageSend.onSenderKeyEncrypted(relatedMessageId.getId());
SignalLocalMetrics.GroupMessageSend.onSenderKeyMessageSent(relatedMessageId.getId());
SignalLocalMetrics.GroupMessageSend.onSenderKeySyncSent(relatedMessageId.getId());
SignalLocalMetrics.GroupMessageSend.onSenderKeyMslInserted(relatedMessageId.getId());
}
if (cancelationSignal != null && cancelationSignal.isCanceled()) {
@@ -285,6 +303,9 @@ public final class GroupSendUtil {
int successCount = (int) results.stream().filter(SendMessageResult::isSuccess).count();
Log.d(TAG, "Successfully sent using 1:1 to " + successCount + "/" + targets.size() + " legacy targets.");
} else if (relatedMessageId != null) {
SignalLocalMetrics.GroupMessageSend.onLegacyMessageSent(relatedMessageId.getId());
SignalLocalMetrics.GroupMessageSend.onLegacySyncFinished(relatedMessageId.getId());
}
if (unregisteredTargets.size() > 0) {
@@ -361,7 +382,8 @@ public final class GroupSendUtil {
boolean isRecipientUpdate)
throws NoSessionException, UntrustedIdentityException, InvalidKeyException, IOException, InvalidRegistrationIdException
{
return messageSender.sendGroupDataMessage(distributionId, targets, access, isRecipientUpdate, contentHint, message);
SenderKeyGroupEvents listener = relatedMessageId != null ? new SenderKeyMetricEventListener(relatedMessageId.getId()) : SenderKeyGroupEvents.EMPTY;
return messageSender.sendGroupDataMessage(distributionId, targets, access, isRecipientUpdate, contentHint, message, listener);
}
@Override
@@ -373,7 +395,8 @@ public final class GroupSendUtil {
@Nullable CancelationSignal cancelationSignal)
throws IOException, UntrustedIdentityException
{
return messageSender.sendDataMessage(targets, access, isRecipientUpdate, contentHint, message, partialListener, cancelationSignal);
LegacyGroupEvents listener = relatedMessageId != null ? new LegacyMetricEventListener(relatedMessageId.getId()) : LegacyGroupEvents.EMPTY;
return messageSender.sendDataMessage(targets, access, isRecipientUpdate, contentHint, message, listener, partialListener, cancelationSignal);
}
@Override
@@ -507,6 +530,54 @@ public final class GroupSendUtil {
}
}
private static final class SenderKeyMetricEventListener implements SenderKeyGroupEvents {
private final long messageId;
private SenderKeyMetricEventListener(long messageId) {
this.messageId = messageId;
}
@Override
public void onSenderKeyShared() {
SignalLocalMetrics.GroupMessageSend.onSenderKeyShared(messageId);
}
@Override
public void onMessageEncrypted() {
SignalLocalMetrics.GroupMessageSend.onSenderKeyEncrypted(messageId);
}
@Override
public void onMessageSent() {
SignalLocalMetrics.GroupMessageSend.onSenderKeyMessageSent(messageId);
}
@Override
public void onSyncMessageSent() {
SignalLocalMetrics.GroupMessageSend.onSenderKeySyncSent(messageId);
}
}
private static final class LegacyMetricEventListener implements LegacyGroupEvents {
private final long messageId;
private LegacyMetricEventListener(long messageId) {
this.messageId = messageId;
}
@Override
public void onMessageSent() {
SignalLocalMetrics.GroupMessageSend.onLegacyMessageSent(messageId);
}
@Override
public void onSyncMessageSent() {
SignalLocalMetrics.GroupMessageSend.onLegacySyncFinished(messageId);
}
}
/**
* Little utility wrapper that lets us get the various different slices of recipient models that we need for different methods.
*/

View File

@@ -181,7 +181,13 @@ public final class SignalLocalMetrics {
private static final String SPLIT_DB_INSERT = "db-insert";
private static final String SPLIT_JOB_ENQUEUE = "job-enqueue";
private static final String SPLIT_JOB_PRE_NETWORK = "job-pre-network";
private static final String SPLIT_NETWORK = "network";
private static final String SPLIT_SENDER_KEY_SHARED = "sk-shared";
private static final String SPLIT_ENCRYPTION = "encryption";
private static final String SPLIT_NETWORK_SENDER_KEY = "network-sk";
private static final String SPLIT_NETWORK_SENDER_KEY_SYNC = "network-sk-sync";
private static final String SPLIT_MSL_SENDER_KEY = "msl-sk";
private static final String SPLIT_NETWORK_LEGACY = "network-legacy";
private static final String SPLIT_NETWORK_LEGACY_SYNC = "network-legacy-sync";
private static final String SPLIT_JOB_POST_NETWORK = "job-post-network";
private static final String SPLIT_UI_UPDATE = "ui-update";
@@ -205,14 +211,44 @@ public final class SignalLocalMetrics {
LocalMetrics.getInstance().split(requireId(messageId), SPLIT_JOB_ENQUEUE);
}
public static void onNetworkStarted(long messageId) {
public static void onSenderKeyStarted(long messageId) {
if (!ID_MAP.containsKey(messageId)) return;
LocalMetrics.getInstance().split(requireId(messageId), SPLIT_JOB_PRE_NETWORK);
}
public static void onNetworkFinished(long messageId) {
public static void onSenderKeyShared(long messageId) {
if (!ID_MAP.containsKey(messageId)) return;
LocalMetrics.getInstance().split(requireId(messageId), SPLIT_NETWORK);
LocalMetrics.getInstance().split(requireId(messageId), SPLIT_SENDER_KEY_SHARED);
}
public static void onSenderKeyEncrypted(long messageId) {
if (!ID_MAP.containsKey(messageId)) return;
LocalMetrics.getInstance().split(requireId(messageId), SPLIT_ENCRYPTION);
}
public static void onSenderKeyMessageSent(long messageId) {
if (!ID_MAP.containsKey(messageId)) return;
LocalMetrics.getInstance().split(requireId(messageId), SPLIT_NETWORK_SENDER_KEY);
}
public static void onSenderKeySyncSent(long messageId) {
if (!ID_MAP.containsKey(messageId)) return;
LocalMetrics.getInstance().split(requireId(messageId), SPLIT_NETWORK_SENDER_KEY_SYNC);
}
public static void onSenderKeyMslInserted(long messageId) {
if (!ID_MAP.containsKey(messageId)) return;
LocalMetrics.getInstance().split(requireId(messageId), SPLIT_MSL_SENDER_KEY);
}
public static void onLegacyMessageSent(long messageId) {
if (!ID_MAP.containsKey(messageId)) return;
LocalMetrics.getInstance().split(requireId(messageId), SPLIT_NETWORK_LEGACY);
}
public static void onLegacySyncFinished(long messageId) {
if (!ID_MAP.containsKey(messageId)) return;
LocalMetrics.getInstance().split(requireId(messageId), SPLIT_NETWORK_LEGACY_SYNC);
}
public static void onJobFinished(long messageId) {

View File

@@ -4,6 +4,7 @@ import android.content.ContentValues;
import android.database.Cursor;
import androidx.annotation.NonNull;
import androidx.annotation.VisibleForTesting;
import com.annimon.stream.Stream;
@@ -19,9 +20,13 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
public final class SqlUtil {
/** The maximum number of arguments (i.e. question marks) allowed in a SQL statement. */
private static final int MAX_QUERY_ARGS = 999;
private SqlUtil() {}
public static boolean tableExists(@NonNull SQLiteDatabase db, @NonNull String table) {
@@ -155,6 +160,41 @@ public final class SqlUtil {
return new Query(column + " IN (" + query.toString() + ")", buildArgs(args));
}
public static @NonNull List<Query> buildCustomCollectionQuery(@NonNull String query, @NonNull List<String[]> argList) {
return buildCustomCollectionQuery(query, argList, MAX_QUERY_ARGS);
}
@VisibleForTesting
static @NonNull List<Query> buildCustomCollectionQuery(@NonNull String query, @NonNull List<String[]> argList, int maxQueryArgs) {
int batchSize = maxQueryArgs / argList.get(0).length;
return Util.chunk(argList, batchSize)
.stream()
.map(argBatch -> buildSingleCustomCollectionQuery(query, argBatch))
.collect(Collectors.toList());
}
private static @NonNull Query buildSingleCustomCollectionQuery(@NonNull String query, @NonNull List<String[]> argList) {
StringBuilder outputQuery = new StringBuilder();
String[] outputArgs = new String[argList.get(0).length * argList.size()];
int argPosition = 0;
for (int i = 0, len = argList.size(); i < len; i++) {
outputQuery.append("(").append(query).append(")");
if (i < len - 1) {
outputQuery.append(" OR ");
}
String[] args = argList.get(i);
for (String arg : args) {
outputArgs[argPosition] = arg;
argPosition++;
}
}
return new Query(outputQuery.toString(), outputArgs);
}
public static @NonNull Query buildQuery(@NonNull String where, @NonNull Object... args) {
return new SqlUtil.Query(where, SqlUtil.buildArgs(args));
}
@@ -168,6 +208,68 @@ public final class SqlUtil {
return output;
}
public static List<Query> buildBulkInsert(@NonNull String tableName, @NonNull String[] columns, List<ContentValues> contentValues) {
return buildBulkInsert(tableName, columns, contentValues, MAX_QUERY_ARGS);
}
@VisibleForTesting
static List<Query> buildBulkInsert(@NonNull String tableName, @NonNull String[] columns, List<ContentValues> contentValues, int maxQueryArgs) {
int batchSize = maxQueryArgs / columns.length;
return Util.chunk(contentValues, batchSize)
.stream()
.map(batch -> buildSingleBulkInsert(tableName, columns, batch))
.collect(Collectors.toList());
}
private static Query buildSingleBulkInsert(@NonNull String tableName, @NonNull String[] columns, List<ContentValues> contentValues) {
StringBuilder builder = new StringBuilder();
builder.append("INSERT INTO ").append(tableName).append(" (");
for (int i = 0; i < columns.length; i++) {
builder.append(columns[i]);
if (i < columns.length - 1) {
builder.append(", ");
}
}
builder.append(") VALUES ");
StringBuilder placeholder = new StringBuilder();
placeholder.append("(");
for (int i = 0; i < columns.length; i++) {
placeholder.append("?");
if (i < columns.length - 1) {
placeholder.append(", ");
}
}
placeholder.append(")");
for (int i = 0, len = contentValues.size(); i < len; i++) {
builder.append(placeholder);
if (i < len - 1) {
builder.append(", ");
}
}
String query = builder.toString();
String[] args = new String[columns.length * contentValues.size()];
int i = 0;
for (ContentValues values : contentValues) {
for (String column : columns) {
Object value = values.get(column);
args[i] = value != null ? values.get(column).toString() : "null";
i++;
}
}
return new Query(query, args);
}
public static class Query {
private final String where;
private final String[] whereArgs;

View File

@@ -10,6 +10,7 @@ import org.robolectric.annotation.Config;
import org.thoughtcrime.securesms.recipients.Recipient;
import org.thoughtcrime.securesms.recipients.RecipientId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@@ -126,6 +127,48 @@ public final class SqlUtilTest {
SqlUtil.buildCollectionQuery("a", Collections.emptyList());
}
@Test
public void buildCustomCollectionQuery_single_singleBatch() {
List<String[]> args = new ArrayList<>();
args.add(SqlUtil.buildArgs(1, 2));
List<SqlUtil.Query> queries = SqlUtil.buildCustomCollectionQuery("a = ? AND b = ?", args);
assertEquals(1, queries.size());
assertEquals("(a = ? AND b = ?)", queries.get(0).getWhere());
assertArrayEquals(new String[] { "1", "2" }, queries.get(0).getWhereArgs());
}
@Test
public void buildCustomCollectionQuery_multiple_singleBatch() {
List<String[]> args = new ArrayList<>();
args.add(SqlUtil.buildArgs(1, 2));
args.add(SqlUtil.buildArgs(3, 4));
args.add(SqlUtil.buildArgs(5, 6));
List<SqlUtil.Query> queries = SqlUtil.buildCustomCollectionQuery("a = ? AND b = ?", args);
assertEquals(1, queries.size());
assertEquals("(a = ? AND b = ?) OR (a = ? AND b = ?) OR (a = ? AND b = ?)", queries.get(0).getWhere());
assertArrayEquals(new String[] { "1", "2", "3", "4", "5", "6" }, queries.get(0).getWhereArgs());
}
@Test
public void buildCustomCollectionQuery_twoBatches() {
List<String[]> args = new ArrayList<>();
args.add(SqlUtil.buildArgs(1, 2));
args.add(SqlUtil.buildArgs(3, 4));
args.add(SqlUtil.buildArgs(5, 6));
List<SqlUtil.Query> queries = SqlUtil.buildCustomCollectionQuery("a = ? AND b = ?", args, 4);
assertEquals(2, queries.size());
assertEquals("(a = ? AND b = ?) OR (a = ? AND b = ?)", queries.get(0).getWhere());
assertArrayEquals(new String[] { "1", "2", "3", "4" }, queries.get(0).getWhereArgs());
assertEquals("(a = ? AND b = ?)", queries.get(1).getWhere());
assertArrayEquals(new String[] { "5", "6" }, queries.get(1).getWhereArgs());
}
@Test
public void splitStatements_singleStatement() {
List<String> result = SqlUtil.splitStatements("SELECT * FROM foo;\n");
@@ -143,4 +186,74 @@ public final class SqlUtilTest {
List<String> result = SqlUtil.splitStatements("SELECT * FROM foo;\n\nSELECT * FROM bar;\n");
assertEquals(Arrays.asList("SELECT * FROM foo", "SELECT * FROM bar"), result);
}
@Test
public void buildBulkInsert_single_singleBatch() {
List<ContentValues> contentValues = new ArrayList<>();
ContentValues cv1 = new ContentValues();
cv1.put("a", 1);
cv1.put("b", 2);
contentValues.add(cv1);
List<SqlUtil.Query> output = SqlUtil.buildBulkInsert("mytable", new String[] { "a", "b"}, contentValues);
assertEquals(1, output.size());
assertEquals("INSERT INTO mytable (a, b) VALUES (?, ?)", output.get(0).getWhere());
assertArrayEquals(new String[] { "1", "2" }, output.get(0).getWhereArgs());
}
@Test
public void buildBulkInsert_multiple_singleBatch() {
List<ContentValues> contentValues = new ArrayList<>();
ContentValues cv1 = new ContentValues();
cv1.put("a", 1);
cv1.put("b", 2);
ContentValues cv2 = new ContentValues();
cv2.put("a", 3);
cv2.put("b", 4);
contentValues.add(cv1);
contentValues.add(cv2);
List<SqlUtil.Query> output = SqlUtil.buildBulkInsert("mytable", new String[] { "a", "b"}, contentValues);
assertEquals(1, output.size());
assertEquals("INSERT INTO mytable (a, b) VALUES (?, ?), (?, ?)", output.get(0).getWhere());
assertArrayEquals(new String[] { "1", "2", "3", "4" }, output.get(0).getWhereArgs());
}
@Test
public void buildBulkInsert_twoBatches() {
List<ContentValues> contentValues = new ArrayList<>();
ContentValues cv1 = new ContentValues();
cv1.put("a", 1);
cv1.put("b", 2);
ContentValues cv2 = new ContentValues();
cv2.put("a", 3);
cv2.put("b", 4);
ContentValues cv3 = new ContentValues();
cv3.put("a", 5);
cv3.put("b", 6);
contentValues.add(cv1);
contentValues.add(cv2);
contentValues.add(cv3);
List<SqlUtil.Query> output = SqlUtil.buildBulkInsert("mytable", new String[] { "a", "b"}, contentValues, 4);
assertEquals(2, output.size());
assertEquals("INSERT INTO mytable (a, b) VALUES (?, ?), (?, ?)", output.get(0).getWhere());
assertArrayEquals(new String[] { "1", "2", "3", "4" }, output.get(0).getWhereArgs());
assertEquals("INSERT INTO mytable (a, b) VALUES (?, ?)", output.get(1).getWhere());
assertArrayEquals(new String[] { "5", "6" }, output.get(1).getWhereArgs());
}
}

View File

@@ -253,7 +253,7 @@ public class SignalServiceMessageSender {
throws IOException, UntrustedIdentityException, InvalidKeyException, NoSessionException, InvalidRegistrationIdException
{
Content content = createTypingContent(message);
sendGroupMessage(distributionId, recipients, unidentifiedAccess, message.getTimestamp(), content, ContentHint.IMPLICIT, message.getGroupId().orNull(), true);
sendGroupMessage(distributionId, recipients, unidentifiedAccess, message.getTimestamp(), content, ContentHint.IMPLICIT, message.getGroupId().orNull(), true, SenderKeyGroupEvents.EMPTY);
}
@@ -293,7 +293,7 @@ public class SignalServiceMessageSender {
throws IOException, UntrustedIdentityException, InvalidKeyException, NoSessionException, InvalidRegistrationIdException
{
Content content = createCallContent(message);
return sendGroupMessage(distributionId, recipients, unidentifiedAccess, message.getTimestamp().get(), content, ContentHint.IMPLICIT, message.getGroupId().get(), false);
return sendGroupMessage(distributionId, recipients, unidentifiedAccess, message.getTimestamp().get(), content, ContentHint.IMPLICIT, message.getGroupId().get(), false, SenderKeyGroupEvents.EMPTY);
}
/**
@@ -420,14 +420,17 @@ public class SignalServiceMessageSender {
List<UnidentifiedAccess> unidentifiedAccess,
boolean isRecipientUpdate,
ContentHint contentHint,
SignalServiceDataMessage message)
SignalServiceDataMessage message,
SenderKeyGroupEvents sendEvents)
throws IOException, UntrustedIdentityException, NoSessionException, InvalidKeyException, InvalidRegistrationIdException
{
Log.d(TAG, "[" + message.getTimestamp() + "] Sending a group data message to " + recipients.size() + " recipients.");
Content content = createMessageContent(message);
Optional<byte[]> groupId = message.getGroupId();
List<SendMessageResult> results = sendGroupMessage(distributionId, recipients, unidentifiedAccess, message.getTimestamp(), content, contentHint, groupId.orNull(), false);
List<SendMessageResult> results = sendGroupMessage(distributionId, recipients, unidentifiedAccess, message.getTimestamp(), content, contentHint, groupId.orNull(), false, sendEvents);
sendEvents.onMessageSent();
if (store.isMultiDevice()) {
Content syncMessage = createMultiDeviceSentTranscriptContent(content, Optional.absent(), message.getTimestamp(), results, isRecipientUpdate);
@@ -436,6 +439,8 @@ public class SignalServiceMessageSender {
sendMessage(localAddress, Optional.absent(), message.getTimestamp(), syncMessageContent, false, null);
}
sendEvents.onSyncMessageSent();
return results;
}
@@ -450,6 +455,7 @@ public class SignalServiceMessageSender {
boolean isRecipientUpdate,
ContentHint contentHint,
SignalServiceDataMessage message,
LegacyGroupEvents sendEvents,
PartialSendCompleteListener partialListener,
CancelationSignal cancelationSignal)
throws IOException, UntrustedIdentityException
@@ -462,6 +468,8 @@ public class SignalServiceMessageSender {
List<SendMessageResult> results = sendMessage(recipients, getTargetUnidentifiedAccess(unidentifiedAccess), timestamp, envelopeContent, false, partialListener, cancelationSignal);
boolean needsSyncInResults = false;
sendEvents.onMessageSent();
for (SendMessageResult result : results) {
if (result.getSuccess() != null && result.getSuccess().isNeedsSync()) {
needsSyncInResults = true;
@@ -481,6 +489,8 @@ public class SignalServiceMessageSender {
sendMessage(localAddress, Optional.absent(), timestamp, syncMessageContent, false, null);
}
sendEvents.onSyncMessageSent();
return results;
}
@@ -1673,7 +1683,8 @@ public class SignalServiceMessageSender {
Content content,
ContentHint contentHint,
byte[] groupId,
boolean online)
boolean online,
SenderKeyGroupEvents sendEvents)
throws IOException, UntrustedIdentityException, NoSessionException, InvalidKeyException, InvalidRegistrationIdException
{
if (recipients.isEmpty()) {
@@ -1751,6 +1762,8 @@ public class SignalServiceMessageSender {
}
}
sendEvents.onSenderKeyShared();
SignalServiceCipher cipher = new SignalServiceCipher(localAddress, store, sessionLock, null);
SenderCertificate senderCertificate = unidentifiedAccess.get(0).getUnidentifiedCertificate();
@@ -1761,6 +1774,8 @@ public class SignalServiceMessageSender {
throw new UntrustedIdentityException("Untrusted during group encrypt", e.getName(), e.getUntrustedIdentity());
}
sendEvents.onMessageEncrypted();
byte[] joinedUnidentifiedAccess = new byte[16];
for (UnidentifiedAccess access : unidentifiedAccess) {
joinedUnidentifiedAccess = ByteArrayUtil.xor(joinedUnidentifiedAccess, access.getUnidentifiedAccessKey());
@@ -2108,4 +2123,38 @@ public class SignalServiceMessageSender {
void onMessageSent();
void onSyncMessageSent();
}
public interface SenderKeyGroupEvents {
SenderKeyGroupEvents EMPTY = new SenderKeyGroupEvents() {
@Override
public void onSenderKeyShared() { }
@Override
public void onMessageEncrypted() { }
@Override
public void onMessageSent() { }
@Override
public void onSyncMessageSent() { }
};
void onSenderKeyShared();
void onMessageEncrypted();
void onMessageSent();
void onSyncMessageSent();
}
public interface LegacyGroupEvents {
LegacyGroupEvents EMPTY = new LegacyGroupEvents() {
@Override
public void onMessageSent() { }
@Override
public void onSyncMessageSent() { }
};
void onMessageSent();
void onSyncMessageSent();
}
}