Improve backup export perf by using better batching.

This commit is contained in:
Greyson Parrelli
2024-10-24 13:46:13 -04:00
parent ec736afde4
commit ebca386dcb
9 changed files with 252 additions and 142 deletions

View File

@@ -283,7 +283,10 @@ object BackupRepository {
try { try {
val dbSnapshot: SignalDatabase = createSignalDatabaseSnapshot(mainDbName) val dbSnapshot: SignalDatabase = createSignalDatabaseSnapshot(mainDbName)
eventTimer.emit("main-db-snapshot")
val signalStoreSnapshot: SignalStore = createSignalStoreSnapshot(keyValueDbName) val signalStoreSnapshot: SignalStore = createSignalStoreSnapshot(keyValueDbName)
eventTimer.emit("store-db-snapshot")
val exportState = ExportState(backupTime = currentTime, mediaBackupEnabled = SignalStore.backup.backsUpMedia) val exportState = ExportState(backupTime = currentTime, mediaBackupEnabled = SignalStore.backup.backsUpMedia)

View File

@@ -5,73 +5,113 @@
package org.thoughtcrime.securesms.backup.v2.database package org.thoughtcrime.securesms.backup.v2.database
import org.signal.core.util.logging.Log
import org.signal.core.util.select import org.signal.core.util.select
import org.thoughtcrime.securesms.backup.v2.ImportState import org.thoughtcrime.securesms.backup.v2.ImportState
import org.thoughtcrime.securesms.backup.v2.exporters.ChatItemArchiveExporter import org.thoughtcrime.securesms.backup.v2.exporters.ChatItemArchiveExporter
import org.thoughtcrime.securesms.backup.v2.importer.ChatItemArchiveImporter import org.thoughtcrime.securesms.backup.v2.importer.ChatItemArchiveImporter
import org.thoughtcrime.securesms.database.MessageTable import org.thoughtcrime.securesms.database.MessageTable
import org.thoughtcrime.securesms.database.SignalDatabase import org.thoughtcrime.securesms.database.SignalDatabase
import java.util.concurrent.TimeUnit
private val TAG = "MessageTableArchiveExtensions"
fun MessageTable.getMessagesForBackup(db: SignalDatabase, backupTime: Long, mediaBackupEnabled: Boolean): ChatItemArchiveExporter { fun MessageTable.getMessagesForBackup(db: SignalDatabase, backupTime: Long, mediaBackupEnabled: Boolean): ChatItemArchiveExporter {
// We create a temporary index on date_received to drastically speed up perf here. // We create a covering index for the query to drastically speed up perf here.
// Remember that we're working on a temporary snapshot of the database, so we can create an index and not worry about cleaning it up. // Remember that we're working on a temporary snapshot of the database, so we can create an index and not worry about cleaning it up.
val startTime = System.currentTimeMillis()
val dateReceivedIndex = "message_date_received" val dateReceivedIndex = "message_date_received"
writableDatabase.execSQL("CREATE INDEX $dateReceivedIndex ON ${MessageTable.TABLE_NAME} (${MessageTable.DATE_RECEIVED} ASC)") writableDatabase.execSQL(
"""CREATE INDEX $dateReceivedIndex ON ${MessageTable.TABLE_NAME} (
val cursor = readableDatabase ${MessageTable.DATE_RECEIVED} ASC,
.select( ${MessageTable.STORY_TYPE},
MessageTable.ID, ${MessageTable.ID},
MessageTable.DATE_SENT, ${MessageTable.DATE_SENT},
MessageTable.DATE_RECEIVED, ${MessageTable.DATE_SERVER},
MessageTable.DATE_SERVER, ${MessageTable.TYPE},
MessageTable.TYPE, ${MessageTable.THREAD_ID},
MessageTable.THREAD_ID, ${MessageTable.BODY},
MessageTable.BODY, ${MessageTable.MESSAGE_RANGES},
MessageTable.MESSAGE_RANGES, ${MessageTable.FROM_RECIPIENT_ID},
MessageTable.FROM_RECIPIENT_ID, ${MessageTable.TO_RECIPIENT_ID},
MessageTable.TO_RECIPIENT_ID, ${MessageTable.EXPIRES_IN},
MessageTable.EXPIRES_IN, ${MessageTable.EXPIRE_STARTED},
MessageTable.EXPIRE_STARTED, ${MessageTable.REMOTE_DELETED},
MessageTable.REMOTE_DELETED, ${MessageTable.UNIDENTIFIED},
MessageTable.UNIDENTIFIED, ${MessageTable.LINK_PREVIEWS},
MessageTable.LINK_PREVIEWS, ${MessageTable.SHARED_CONTACTS},
MessageTable.SHARED_CONTACTS, ${MessageTable.QUOTE_ID},
MessageTable.QUOTE_ID, ${MessageTable.QUOTE_AUTHOR},
MessageTable.QUOTE_AUTHOR, ${MessageTable.QUOTE_BODY},
MessageTable.QUOTE_BODY, ${MessageTable.QUOTE_MISSING},
MessageTable.QUOTE_MISSING, ${MessageTable.QUOTE_BODY_RANGES},
MessageTable.QUOTE_BODY_RANGES, ${MessageTable.QUOTE_TYPE},
MessageTable.QUOTE_TYPE, ${MessageTable.ORIGINAL_MESSAGE_ID},
MessageTable.ORIGINAL_MESSAGE_ID, ${MessageTable.LATEST_REVISION_ID},
MessageTable.LATEST_REVISION_ID, ${MessageTable.HAS_DELIVERY_RECEIPT},
MessageTable.HAS_DELIVERY_RECEIPT, ${MessageTable.HAS_READ_RECEIPT},
MessageTable.HAS_READ_RECEIPT, ${MessageTable.VIEWED_COLUMN},
MessageTable.VIEWED_COLUMN, ${MessageTable.RECEIPT_TIMESTAMP},
MessageTable.RECEIPT_TIMESTAMP, ${MessageTable.READ},
MessageTable.READ, ${MessageTable.NETWORK_FAILURES},
MessageTable.NETWORK_FAILURES, ${MessageTable.MISMATCHED_IDENTITIES},
MessageTable.MISMATCHED_IDENTITIES, ${MessageTable.TYPE},
MessageTable.TYPE, ${MessageTable.MESSAGE_EXTRAS},
MessageTable.MESSAGE_EXTRAS, ${MessageTable.VIEW_ONCE}
MessageTable.VIEW_ONCE
) )
.from("${MessageTable.TABLE_NAME} INDEXED BY $dateReceivedIndex") """.trimMargin()
.where( )
""" Log.d(TAG, "Creating index took ${System.currentTimeMillis() - startTime} ms")
(
${MessageTable.EXPIRE_STARTED} = 0
OR
(${MessageTable.EXPIRES_IN} > 0 AND (${MessageTable.EXPIRE_STARTED} + ${MessageTable.EXPIRES_IN}) > $backupTime + ${TimeUnit.DAYS.toMillis(1)})
)
AND ${MessageTable.STORY_TYPE} = 0
"""
)
.orderBy("${MessageTable.DATE_RECEIVED} ASC")
.run()
return ChatItemArchiveExporter(db, cursor, 100, mediaBackupEnabled) return ChatItemArchiveExporter(
db = db,
backupStartTime = backupTime,
batchSize = 10_000,
mediaArchiveEnabled = mediaBackupEnabled,
cursorGenerator = { lastSeenReceivedTime, count ->
readableDatabase
.select(
MessageTable.ID,
MessageTable.DATE_SENT,
MessageTable.DATE_RECEIVED,
MessageTable.DATE_SERVER,
MessageTable.TYPE,
MessageTable.THREAD_ID,
MessageTable.BODY,
MessageTable.MESSAGE_RANGES,
MessageTable.FROM_RECIPIENT_ID,
MessageTable.TO_RECIPIENT_ID,
MessageTable.EXPIRES_IN,
MessageTable.EXPIRE_STARTED,
MessageTable.REMOTE_DELETED,
MessageTable.UNIDENTIFIED,
MessageTable.LINK_PREVIEWS,
MessageTable.SHARED_CONTACTS,
MessageTable.QUOTE_ID,
MessageTable.QUOTE_AUTHOR,
MessageTable.QUOTE_BODY,
MessageTable.QUOTE_MISSING,
MessageTable.QUOTE_BODY_RANGES,
MessageTable.QUOTE_TYPE,
MessageTable.ORIGINAL_MESSAGE_ID,
MessageTable.LATEST_REVISION_ID,
MessageTable.HAS_DELIVERY_RECEIPT,
MessageTable.HAS_READ_RECEIPT,
MessageTable.VIEWED_COLUMN,
MessageTable.RECEIPT_TIMESTAMP,
MessageTable.READ,
MessageTable.NETWORK_FAILURES,
MessageTable.MISMATCHED_IDENTITIES,
MessageTable.TYPE,
MessageTable.MESSAGE_EXTRAS,
MessageTable.VIEW_ONCE
)
.from("${MessageTable.TABLE_NAME} INDEXED BY $dateReceivedIndex")
.where("${MessageTable.STORY_TYPE} = 0 AND ${MessageTable.DATE_RECEIVED} >= $lastSeenReceivedTime")
.limit(count)
.orderBy("${MessageTable.DATE_RECEIVED} ASC")
.run()
}
)
} }
fun MessageTable.createChatItemInserter(importState: ImportState): ChatItemArchiveImporter { fun MessageTable.createChatItemInserter(importState: ImportState): ChatItemArchiveImporter {

View File

@@ -12,6 +12,7 @@ import org.json.JSONException
import org.signal.core.util.Base64 import org.signal.core.util.Base64
import org.signal.core.util.EventTimer import org.signal.core.util.EventTimer
import org.signal.core.util.Hex import org.signal.core.util.Hex
import org.signal.core.util.concurrent.SignalExecutors
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import org.signal.core.util.nullIfEmpty import org.signal.core.util.nullIfEmpty
import org.signal.core.util.orNull import org.signal.core.util.orNull
@@ -83,7 +84,11 @@ import java.io.IOException
import java.util.HashMap import java.util.HashMap
import java.util.LinkedList import java.util.LinkedList
import java.util.Queue import java.util.Queue
import java.util.concurrent.Callable
import java.util.concurrent.ExecutorService
import java.util.concurrent.Future
import kotlin.jvm.optionals.getOrNull import kotlin.jvm.optionals.getOrNull
import kotlin.time.Duration.Companion.days
import org.thoughtcrime.securesms.backup.v2.proto.BodyRange as BackupBodyRange import org.thoughtcrime.securesms.backup.v2.proto.BodyRange as BackupBodyRange
import org.thoughtcrime.securesms.backup.v2.proto.GiftBadge as BackupGiftBadge import org.thoughtcrime.securesms.backup.v2.proto.GiftBadge as BackupGiftBadge
@@ -98,9 +103,10 @@ private val TAG = Log.tag(ChatItemArchiveExporter::class.java)
*/ */
class ChatItemArchiveExporter( class ChatItemArchiveExporter(
private val db: SignalDatabase, private val db: SignalDatabase,
private val cursor: Cursor, private val backupStartTime: Long,
private val batchSize: Int, private val batchSize: Int,
private val mediaArchiveEnabled: Boolean private val mediaArchiveEnabled: Boolean,
private val cursorGenerator: (Long, Int) -> Cursor
) : Iterator<ChatItem?>, Closeable { ) : Iterator<ChatItem?>, Closeable {
private val eventTimer = EventTimer() private val eventTimer = EventTimer()
@@ -113,8 +119,12 @@ class ChatItemArchiveExporter(
private val revisionMap: HashMap<Long, ArrayList<ChatItem>> = HashMap() private val revisionMap: HashMap<Long, ArrayList<ChatItem>> = HashMap()
private var lastSeenReceivedTime = 0L
private var records: LinkedHashMap<Long, BackupMessageRecord> = readNextMessageRecordBatch(emptySet())
override fun hasNext(): Boolean { override fun hasNext(): Boolean {
return buffer.isNotEmpty() || (cursor.count > 0 && !cursor.isLast && !cursor.isAfterLast) return buffer.isNotEmpty() || records.isNotEmpty()
} }
override fun next(): ChatItem? { override fun next(): ChatItem? {
@@ -122,19 +132,8 @@ class ChatItemArchiveExporter(
return buffer.remove() return buffer.remove()
} }
val records: LinkedHashMap<Long, BackupMessageRecord> = LinkedHashMap(batchSize)
for (i in 0 until batchSize) {
if (cursor.moveToNext()) {
val record = cursor.toBackupMessageRecord()
records[record.id] = record
} else {
break
}
}
eventTimer.emit("messages")
val extraData = fetchExtraMessageData(db, records.keys) val extraData = fetchExtraMessageData(db, records.keys)
eventTimer.emit("extra-data")
for ((id, record) in records) { for ((id, record) in records) {
val builder = record.toBasicChatItemBuilder(extraData.groupReceiptsById[id]) val builder = record.toBasicChatItemBuilder(extraData.groupReceiptsById[id])
@@ -288,6 +287,13 @@ class ChatItemArchiveExporter(
previousEdits += builder.build() previousEdits += builder.build()
} }
} }
eventTimer.emit("transform")
val recordIds = HashSet(records.keys)
records.clear()
records = readNextMessageRecordBatch(recordIds)
eventTimer.emit("messages")
return if (buffer.isNotEmpty()) { return if (buffer.isNotEmpty()) {
buffer.remove() buffer.remove()
@@ -297,46 +303,45 @@ class ChatItemArchiveExporter(
} }
override fun close() { override fun close() {
cursor.close() Log.d(TAG, "[ChatItemArchiveExporter][batchSize = $batchSize] ${eventTimer.stop().summary}")
Log.w(TAG, "[ChatItemArchiveExporter] ${eventTimer.stop().summary}") }
private fun readNextMessageRecordBatch(pastIds: Set<Long>): LinkedHashMap<Long, BackupMessageRecord> {
return cursorGenerator(lastSeenReceivedTime, batchSize).use { cursor ->
val records: LinkedHashMap<Long, BackupMessageRecord> = LinkedHashMap(batchSize)
while (cursor.moveToNext()) {
cursor.toBackupMessageRecord(pastIds, backupStartTime)?.let { record ->
records[record.id] = record
lastSeenReceivedTime = record.dateReceived
}
}
records
}
} }
private fun fetchExtraMessageData(db: SignalDatabase, messageIds: Set<Long>): ExtraMessageData { private fun fetchExtraMessageData(db: SignalDatabase, messageIds: Set<Long>): ExtraMessageData {
// TODO [backup] This seems to be a wash val executor = SignalExecutors.BOUNDED
// val executor = SignalExecutors.BOUNDED
//
// val mentionsFuture = executor.submitTyped {
// db.mentionTable.getMentionsForMessages(messageIds)
// }
//
// val reactionsFuture = executor.submitTyped {
// db.reactionTable.getReactionsForMessages(messageIds)
// }
//
// val attachmentsFuture = executor.submitTyped {
// db.attachmentTable.getAttachmentsForMessages(messageIds)
// }
//
// val groupReceiptsFuture = executor.submitTyped {
// db.groupReceiptTable.getGroupReceiptInfoForMessages(messageIds)
// }
//
// val mentionsResult = mentionsFuture.get()
// val reactionsResult = reactionsFuture.get()
// val attachmentsResult = attachmentsFuture.get()
// val groupReceiptsResult = groupReceiptsFuture.get()
val mentionsResult = db.mentionTable.getMentionsForMessages(messageIds) val mentionsFuture = executor.submitTyped {
eventTimer.emit("mentions") db.mentionTable.getMentionsForMessages(messageIds)
}
val reactionsResult = db.reactionTable.getReactionsForMessages(messageIds) val reactionsFuture = executor.submitTyped {
eventTimer.emit("reactions") db.reactionTable.getReactionsForMessages(messageIds)
}
val attachmentsResult = db.attachmentTable.getAttachmentsForMessages(messageIds) val attachmentsFuture = executor.submitTyped {
eventTimer.emit("attachments") db.attachmentTable.getAttachmentsForMessages(messageIds)
}
val groupReceiptsResult = db.groupReceiptTable.getGroupReceiptInfoForMessages(messageIds) val groupReceiptsFuture = executor.submitTyped {
eventTimer.emit("receipts") db.groupReceiptTable.getGroupReceiptInfoForMessages(messageIds)
}
val mentionsResult = mentionsFuture.get()
val reactionsResult = reactionsFuture.get()
val attachmentsResult = attachmentsFuture.get()
val groupReceiptsResult = groupReceiptsFuture.get()
return ExtraMessageData( return ExtraMessageData(
mentionsById = mentionsResult, mentionsById = mentionsResult,
@@ -1104,13 +1109,25 @@ private fun String.e164ToLong(): Long? {
return fixed.toLongOrNull() return fixed.toLongOrNull()
} }
// private fun <T> ExecutorService.submitTyped(callable: Callable<T>): Future<T> { private fun <T> ExecutorService.submitTyped(callable: Callable<T>): Future<T> {
// return this.submit(callable) return this.submit(callable)
// } }
private fun Cursor.toBackupMessageRecord(pastIds: Set<Long>, backupStartTime: Long): BackupMessageRecord? {
val id = this.requireLong(MessageTable.ID)
if (pastIds.contains(id)) {
return null
}
val expiresIn = this.requireLong(MessageTable.EXPIRES_IN)
val expireStarted = this.requireLong(MessageTable.EXPIRE_STARTED)
if (expireStarted != 0L && expireStarted + expiresIn < backupStartTime + 1.days.inWholeMilliseconds) {
return null
}
private fun Cursor.toBackupMessageRecord(): BackupMessageRecord {
return BackupMessageRecord( return BackupMessageRecord(
id = this.requireLong(MessageTable.ID), id = id,
dateSent = this.requireLong(MessageTable.DATE_SENT), dateSent = this.requireLong(MessageTable.DATE_SENT),
dateReceived = this.requireLong(MessageTable.DATE_RECEIVED), dateReceived = this.requireLong(MessageTable.DATE_RECEIVED),
dateServer = this.requireLong(MessageTable.DATE_SERVER), dateServer = this.requireLong(MessageTable.DATE_SERVER),
@@ -1120,8 +1137,8 @@ private fun Cursor.toBackupMessageRecord(): BackupMessageRecord {
bodyRanges = this.requireBlob(MessageTable.MESSAGE_RANGES), bodyRanges = this.requireBlob(MessageTable.MESSAGE_RANGES),
fromRecipientId = this.requireLong(MessageTable.FROM_RECIPIENT_ID), fromRecipientId = this.requireLong(MessageTable.FROM_RECIPIENT_ID),
toRecipientId = this.requireLong(MessageTable.TO_RECIPIENT_ID), toRecipientId = this.requireLong(MessageTable.TO_RECIPIENT_ID),
expiresIn = this.requireLong(MessageTable.EXPIRES_IN), expiresIn = expiresIn,
expireStarted = this.requireLong(MessageTable.EXPIRE_STARTED), expireStarted = expireStarted,
remoteDeleted = this.requireBoolean(MessageTable.REMOTE_DELETED), remoteDeleted = this.requireBoolean(MessageTable.REMOTE_DELETED),
sealedSender = this.requireBoolean(MessageTable.UNIDENTIFIED), sealedSender = this.requireBoolean(MessageTable.UNIDENTIFIED),
linkPreview = this.requireString(MessageTable.LINK_PREVIEWS), linkPreview = this.requireString(MessageTable.LINK_PREVIEWS),

View File

@@ -434,7 +434,7 @@ class AttachmentTable(
return emptyMap() return emptyMap()
} }
val query = SqlUtil.buildSingleCollectionQuery(MESSAGE_ID, mmsIds) val query = SqlUtil.buildFastCollectionQuery(MESSAGE_ID, mmsIds)
return readableDatabase return readableDatabase
.select(*PROJECTION) .select(*PROJECTION)

View File

@@ -140,20 +140,17 @@ class GroupReceiptTable(context: Context?, databaseHelper: SignalDatabase?) : Da
val messageIdsToGroupReceipts: MutableMap<Long, MutableList<GroupReceiptInfo>> = mutableMapOf() val messageIdsToGroupReceipts: MutableMap<Long, MutableList<GroupReceiptInfo>> = mutableMapOf()
val args: List<Array<String>> = ids.map { SqlUtil.buildArgs(it) } val query = SqlUtil.buildFastCollectionQuery(MMS_ID, ids)
readableDatabase
SqlUtil.buildCustomCollectionQuery("$MMS_ID = ?", args).forEach { query -> .select()
readableDatabase .from(TABLE_NAME)
.select() .where(query.where, query.whereArgs)
.from(TABLE_NAME) .run()
.where(query.where, query.whereArgs) .forEach { cursor ->
.run() val messageId = cursor.requireLong(MMS_ID)
.forEach { cursor -> val receipts = messageIdsToGroupReceipts.getOrPut(messageId) { mutableListOf() }
val messageId = cursor.requireLong(MMS_ID) receipts += cursor.toGroupReceiptInfo()
val receipts = messageIdsToGroupReceipts.getOrPut(messageId) { mutableListOf() } }
receipts += cursor.toGroupReceiptInfo()
}
}
return messageIdsToGroupReceipts return messageIdsToGroupReceipts
} }

View File

@@ -2,6 +2,7 @@ package org.thoughtcrime.securesms.database
import android.content.Context import android.content.Context
import android.database.Cursor import android.database.Cursor
import org.signal.core.util.SqlUtil
import org.signal.core.util.delete import org.signal.core.util.delete
import org.signal.core.util.deleteAll import org.signal.core.util.deleteAll
import org.signal.core.util.insertInto import org.signal.core.util.insertInto
@@ -77,12 +78,16 @@ class MentionTable(context: Context, databaseHelper: SignalDatabase) : DatabaseT
} }
fun getMentionsForMessages(messageIds: Collection<Long>): Map<Long, List<Mention>> { fun getMentionsForMessages(messageIds: Collection<Long>): Map<Long, List<Mention>> {
val ids = messageIds.joinToString(separator = ",") { it.toString() } if (messageIds.isEmpty()) {
return emptyMap()
}
val query = SqlUtil.buildFastCollectionQuery(MESSAGE_ID, messageIds)
return readableDatabase return readableDatabase
.select() .select()
.from("$TABLE_NAME INDEXED BY $MESSAGE_ID_INDEX") .from("$TABLE_NAME INDEXED BY $MESSAGE_ID_INDEX")
.where("$MESSAGE_ID IN ($ids)") .where(query.where, query.whereArgs)
.run() .run()
.use { cursor -> readMentions(cursor) } .use { cursor -> readMentions(cursor) }
} }

View File

@@ -6,6 +6,8 @@ import android.database.Cursor
import org.signal.core.util.CursorUtil import org.signal.core.util.CursorUtil
import org.signal.core.util.SqlUtil import org.signal.core.util.SqlUtil
import org.signal.core.util.delete import org.signal.core.util.delete
import org.signal.core.util.forEach
import org.signal.core.util.select
import org.signal.core.util.update import org.signal.core.util.update
import org.thoughtcrime.securesms.database.model.MessageId import org.thoughtcrime.securesms.database.model.MessageId
import org.thoughtcrime.securesms.database.model.ReactionRecord import org.thoughtcrime.securesms.database.model.ReactionRecord
@@ -77,25 +79,25 @@ class ReactionTable(context: Context, databaseHelper: SignalDatabase) : Database
val messageIdToReactions: MutableMap<Long, MutableList<ReactionRecord>> = mutableMapOf() val messageIdToReactions: MutableMap<Long, MutableList<ReactionRecord>> = mutableMapOf()
val args: List<Array<String>> = messageIds.map { SqlUtil.buildArgs(it) } val query = SqlUtil.buildFastCollectionQuery(MESSAGE_ID, messageIds)
readableDatabase
.select()
.from(TABLE_NAME)
.where(query.where, query.whereArgs)
.run()
.forEach { cursor ->
val reaction: ReactionRecord = readReaction(cursor)
val messageId = CursorUtil.requireLong(cursor, MESSAGE_ID)
for (query: SqlUtil.Query in SqlUtil.buildCustomCollectionQuery("$MESSAGE_ID = ?", args)) { var reactionsList: MutableList<ReactionRecord>? = messageIdToReactions[messageId]
readableDatabase.query(TABLE_NAME, null, query.where, query.whereArgs, null, null, null).use { cursor ->
while (cursor.moveToNext()) {
val reaction: ReactionRecord = readReaction(cursor)
val messageId = CursorUtil.requireLong(cursor, MESSAGE_ID)
var reactionsList: MutableList<ReactionRecord>? = messageIdToReactions[messageId] if (reactionsList == null) {
reactionsList = mutableListOf()
if (reactionsList == null) { messageIdToReactions[messageId] = reactionsList
reactionsList = mutableListOf()
messageIdToReactions[messageId] = reactionsList
}
reactionsList.add(reaction)
} }
reactionsList.add(reaction)
} }
}
return messageIdToReactions return messageIdToReactions
} }

View File

@@ -154,6 +154,17 @@ object SqlUtil {
}.toTypedArray() }.toTypedArray()
} }
@JvmStatic
fun buildArgs(objects: Collection<Any?>): Array<String> {
return objects.map {
when (it) {
null -> throw NullPointerException("Cannot have null arg!")
is DatabaseId -> it.serialize()
else -> it.toString()
}
}.toTypedArray()
}
@JvmStatic @JvmStatic
fun buildArgs(argument: Long): Array<String> { fun buildArgs(argument: Long): Array<String> {
return arrayOf(argument.toString()) return arrayOf(argument.toString())
@@ -290,6 +301,20 @@ object SqlUtil {
} }
} }
/**
* A convenient way of making queries that are _equivalent_ to `WHERE [column] IN (?, ?, ..., ?)`
* Under the hood, it uses JSON1 functions which can both be surprisingly faster than normal (?, ?, ?) lists, as well as removes the [MAX_QUERY_ARGS] limit.
* This means chunking isn't necessary for any practical collection length.
*/
@JvmStatic
fun buildFastCollectionQuery(
column: String,
values: Collection<Any?>
): Query {
require(!values.isEmpty()) { "Must have values!" }
return Query("$column IN (SELECT e.value FROM json_each(?) e)", arrayOf(jsonEncode(buildArgs(values))))
}
/** /**
* A convenient way of making queries in the form: WHERE [column] IN (?, ?, ..., ?) * A convenient way of making queries in the form: WHERE [column] IN (?, ?, ..., ?)
* *
@@ -453,6 +478,11 @@ object SqlUtil {
return null return null
} }
/** Simple encoding of a string array as a json array */
private fun jsonEncode(strings: Array<String>): String {
return strings.joinToString(prefix = "[", postfix = "]", separator = ",") { "\"$it\"" }
}
class Query(val where: String, val whereArgs: Array<String>) { class Query(val where: String, val whereArgs: Array<String>) {
infix fun and(other: Query): Query { infix fun and(other: Query): Query {
return if (where.isNotEmpty() && other.where.isNotEmpty()) { return if (where.isNotEmpty() && other.where.isNotEmpty()) {

View File

@@ -170,6 +170,22 @@ public final class SqlUtilTest {
assertTrue(results.isEmpty()); assertTrue(results.isEmpty());
} }
@Test
public void buildFastCollectionQuery_single() {
SqlUtil.Query updateQuery = SqlUtil.buildFastCollectionQuery("a", Arrays.asList(1));
assertEquals("a IN (SELECT e.value FROM json_each(?) e)", updateQuery.getWhere());
assertArrayEquals(new String[] { "[\"1\"]" }, updateQuery.getWhereArgs());
}
@Test
public void buildFastCollectionQuery_multiple() {
SqlUtil.Query updateQuery = SqlUtil.buildFastCollectionQuery("a", Arrays.asList(1, 2, 3));
assertEquals("a IN (SELECT e.value FROM json_each(?) e)", updateQuery.getWhere());
assertArrayEquals(new String[] { "[\"1\",\"2\",\"3\"]" }, updateQuery.getWhereArgs());
}
@Test @Test
public void buildCustomCollectionQuery_single_singleBatch() { public void buildCustomCollectionQuery_single_singleBatch() {
List<String[]> args = new ArrayList<>(); List<String[]> args = new ArrayList<>();