Improve backup export perf by using better batching.

This commit is contained in:
Greyson Parrelli
2024-10-24 13:46:13 -04:00
parent ec736afde4
commit ebca386dcb
9 changed files with 252 additions and 142 deletions

View File

@@ -283,7 +283,10 @@ object BackupRepository {
try {
val dbSnapshot: SignalDatabase = createSignalDatabaseSnapshot(mainDbName)
eventTimer.emit("main-db-snapshot")
val signalStoreSnapshot: SignalStore = createSignalStoreSnapshot(keyValueDbName)
eventTimer.emit("store-db-snapshot")
val exportState = ExportState(backupTime = currentTime, mediaBackupEnabled = SignalStore.backup.backsUpMedia)

View File

@@ -5,22 +5,70 @@
package org.thoughtcrime.securesms.backup.v2.database
import org.signal.core.util.logging.Log
import org.signal.core.util.select
import org.thoughtcrime.securesms.backup.v2.ImportState
import org.thoughtcrime.securesms.backup.v2.exporters.ChatItemArchiveExporter
import org.thoughtcrime.securesms.backup.v2.importer.ChatItemArchiveImporter
import org.thoughtcrime.securesms.database.MessageTable
import org.thoughtcrime.securesms.database.SignalDatabase
import java.util.concurrent.TimeUnit
private val TAG = "MessageTableArchiveExtensions"
fun MessageTable.getMessagesForBackup(db: SignalDatabase, backupTime: Long, mediaBackupEnabled: Boolean): ChatItemArchiveExporter {
// We create a temporary index on date_received to drastically speed up perf here.
// We create a covering index for the query to drastically speed up perf here.
// Remember that we're working on a temporary snapshot of the database, so we can create an index and not worry about cleaning it up.
val startTime = System.currentTimeMillis()
val dateReceivedIndex = "message_date_received"
writableDatabase.execSQL("CREATE INDEX $dateReceivedIndex ON ${MessageTable.TABLE_NAME} (${MessageTable.DATE_RECEIVED} ASC)")
writableDatabase.execSQL(
"""CREATE INDEX $dateReceivedIndex ON ${MessageTable.TABLE_NAME} (
${MessageTable.DATE_RECEIVED} ASC,
${MessageTable.STORY_TYPE},
${MessageTable.ID},
${MessageTable.DATE_SENT},
${MessageTable.DATE_SERVER},
${MessageTable.TYPE},
${MessageTable.THREAD_ID},
${MessageTable.BODY},
${MessageTable.MESSAGE_RANGES},
${MessageTable.FROM_RECIPIENT_ID},
${MessageTable.TO_RECIPIENT_ID},
${MessageTable.EXPIRES_IN},
${MessageTable.EXPIRE_STARTED},
${MessageTable.REMOTE_DELETED},
${MessageTable.UNIDENTIFIED},
${MessageTable.LINK_PREVIEWS},
${MessageTable.SHARED_CONTACTS},
${MessageTable.QUOTE_ID},
${MessageTable.QUOTE_AUTHOR},
${MessageTable.QUOTE_BODY},
${MessageTable.QUOTE_MISSING},
${MessageTable.QUOTE_BODY_RANGES},
${MessageTable.QUOTE_TYPE},
${MessageTable.ORIGINAL_MESSAGE_ID},
${MessageTable.LATEST_REVISION_ID},
${MessageTable.HAS_DELIVERY_RECEIPT},
${MessageTable.HAS_READ_RECEIPT},
${MessageTable.VIEWED_COLUMN},
${MessageTable.RECEIPT_TIMESTAMP},
${MessageTable.READ},
${MessageTable.NETWORK_FAILURES},
${MessageTable.MISMATCHED_IDENTITIES},
${MessageTable.TYPE},
${MessageTable.MESSAGE_EXTRAS},
${MessageTable.VIEW_ONCE}
)
""".trimMargin()
)
Log.d(TAG, "Creating index took ${System.currentTimeMillis() - startTime} ms")
val cursor = readableDatabase
return ChatItemArchiveExporter(
db = db,
backupStartTime = backupTime,
batchSize = 10_000,
mediaArchiveEnabled = mediaBackupEnabled,
cursorGenerator = { lastSeenReceivedTime, count ->
readableDatabase
.select(
MessageTable.ID,
MessageTable.DATE_SENT,
@@ -58,20 +106,12 @@ fun MessageTable.getMessagesForBackup(db: SignalDatabase, backupTime: Long, medi
MessageTable.VIEW_ONCE
)
.from("${MessageTable.TABLE_NAME} INDEXED BY $dateReceivedIndex")
.where(
"""
(
${MessageTable.EXPIRE_STARTED} = 0
OR
(${MessageTable.EXPIRES_IN} > 0 AND (${MessageTable.EXPIRE_STARTED} + ${MessageTable.EXPIRES_IN}) > $backupTime + ${TimeUnit.DAYS.toMillis(1)})
)
AND ${MessageTable.STORY_TYPE} = 0
"""
)
.where("${MessageTable.STORY_TYPE} = 0 AND ${MessageTable.DATE_RECEIVED} >= $lastSeenReceivedTime")
.limit(count)
.orderBy("${MessageTable.DATE_RECEIVED} ASC")
.run()
return ChatItemArchiveExporter(db, cursor, 100, mediaBackupEnabled)
}
)
}
fun MessageTable.createChatItemInserter(importState: ImportState): ChatItemArchiveImporter {

View File

@@ -12,6 +12,7 @@ import org.json.JSONException
import org.signal.core.util.Base64
import org.signal.core.util.EventTimer
import org.signal.core.util.Hex
import org.signal.core.util.concurrent.SignalExecutors
import org.signal.core.util.logging.Log
import org.signal.core.util.nullIfEmpty
import org.signal.core.util.orNull
@@ -83,7 +84,11 @@ import java.io.IOException
import java.util.HashMap
import java.util.LinkedList
import java.util.Queue
import java.util.concurrent.Callable
import java.util.concurrent.ExecutorService
import java.util.concurrent.Future
import kotlin.jvm.optionals.getOrNull
import kotlin.time.Duration.Companion.days
import org.thoughtcrime.securesms.backup.v2.proto.BodyRange as BackupBodyRange
import org.thoughtcrime.securesms.backup.v2.proto.GiftBadge as BackupGiftBadge
@@ -98,9 +103,10 @@ private val TAG = Log.tag(ChatItemArchiveExporter::class.java)
*/
class ChatItemArchiveExporter(
private val db: SignalDatabase,
private val cursor: Cursor,
private val backupStartTime: Long,
private val batchSize: Int,
private val mediaArchiveEnabled: Boolean
private val mediaArchiveEnabled: Boolean,
private val cursorGenerator: (Long, Int) -> Cursor
) : Iterator<ChatItem?>, Closeable {
private val eventTimer = EventTimer()
@@ -113,8 +119,12 @@ class ChatItemArchiveExporter(
private val revisionMap: HashMap<Long, ArrayList<ChatItem>> = HashMap()
private var lastSeenReceivedTime = 0L
private var records: LinkedHashMap<Long, BackupMessageRecord> = readNextMessageRecordBatch(emptySet())
override fun hasNext(): Boolean {
return buffer.isNotEmpty() || (cursor.count > 0 && !cursor.isLast && !cursor.isAfterLast)
return buffer.isNotEmpty() || records.isNotEmpty()
}
override fun next(): ChatItem? {
@@ -122,19 +132,8 @@ class ChatItemArchiveExporter(
return buffer.remove()
}
val records: LinkedHashMap<Long, BackupMessageRecord> = LinkedHashMap(batchSize)
for (i in 0 until batchSize) {
if (cursor.moveToNext()) {
val record = cursor.toBackupMessageRecord()
records[record.id] = record
} else {
break
}
}
eventTimer.emit("messages")
val extraData = fetchExtraMessageData(db, records.keys)
eventTimer.emit("extra-data")
for ((id, record) in records) {
val builder = record.toBasicChatItemBuilder(extraData.groupReceiptsById[id])
@@ -288,6 +287,13 @@ class ChatItemArchiveExporter(
previousEdits += builder.build()
}
}
eventTimer.emit("transform")
val recordIds = HashSet(records.keys)
records.clear()
records = readNextMessageRecordBatch(recordIds)
eventTimer.emit("messages")
return if (buffer.isNotEmpty()) {
buffer.remove()
@@ -297,46 +303,45 @@ class ChatItemArchiveExporter(
}
override fun close() {
cursor.close()
Log.w(TAG, "[ChatItemArchiveExporter] ${eventTimer.stop().summary}")
Log.d(TAG, "[ChatItemArchiveExporter][batchSize = $batchSize] ${eventTimer.stop().summary}")
}
/**
 * Reads the next batch of message records, resuming from [lastSeenReceivedTime].
 *
 * Rows are skipped when [toBackupMessageRecord] returns null, which happens for ids contained in [pastIds]
 * and for messages that are about to expire. Because an empty result is what tells [hasNext] that the export
 * is finished, this keeps querying while a full page of rows was read but every row in it was skipped.
 * Otherwise a long run of skipped rows would end the export early and silently drop every later message.
 *
 * Assumes [cursorGenerator] returns at most `batchSize` rows with a received time greater than or equal to the
 * supplied time, ordered by ascending received time (the visible caller does this) -- TODO confirm for other callers.
 *
 * @param pastIds ids emitted by the previous batch. The resume query is inclusive, so rows sharing the
 * boundary timestamp are read again and must be dropped here.
 * @return the records to export next, keyed by message id in read order; empty once nothing is left.
 */
private fun readNextMessageRecordBatch(pastIds: Set<Long>): LinkedHashMap<Long, BackupMessageRecord> {
  val records: LinkedHashMap<Long, BackupMessageRecord> = LinkedHashMap(batchSize)

  while (records.isEmpty()) {
    val queryStartTime = lastSeenReceivedTime
    var rowsRead = 0
    var newestReceivedTime = queryStartTime

    cursorGenerator(queryStartTime, batchSize).use { cursor ->
      while (cursor.moveToNext()) {
        rowsRead++
        // Track every row, not only the kept ones, so the resume point moves past skipped rows too.
        newestReceivedTime = cursor.requireLong(MessageTable.DATE_RECEIVED)

        cursor.toBackupMessageRecord(pastIds, backupStartTime)?.let { record ->
          records[record.id] = record
        }
      }
    }

    lastSeenReceivedTime = newestReceivedTime

    // A short page means there is nothing left to read. A full page that did not move the resume point
    // (every row shares one timestamp) cannot make progress, so stop instead of looping forever.
    if (rowsRead < batchSize || newestReceivedTime == queryStartTime) {
      break
    }
  }

  return records
}
private fun fetchExtraMessageData(db: SignalDatabase, messageIds: Set<Long>): ExtraMessageData {
// TODO [backup] This seems to be a wash
// val executor = SignalExecutors.BOUNDED
//
// val mentionsFuture = executor.submitTyped {
// db.mentionTable.getMentionsForMessages(messageIds)
// }
//
// val reactionsFuture = executor.submitTyped {
// db.reactionTable.getReactionsForMessages(messageIds)
// }
//
// val attachmentsFuture = executor.submitTyped {
// db.attachmentTable.getAttachmentsForMessages(messageIds)
// }
//
// val groupReceiptsFuture = executor.submitTyped {
// db.groupReceiptTable.getGroupReceiptInfoForMessages(messageIds)
// }
//
// val mentionsResult = mentionsFuture.get()
// val reactionsResult = reactionsFuture.get()
// val attachmentsResult = attachmentsFuture.get()
// val groupReceiptsResult = groupReceiptsFuture.get()
val executor = SignalExecutors.BOUNDED
val mentionsResult = db.mentionTable.getMentionsForMessages(messageIds)
eventTimer.emit("mentions")
val mentionsFuture = executor.submitTyped {
db.mentionTable.getMentionsForMessages(messageIds)
}
val reactionsResult = db.reactionTable.getReactionsForMessages(messageIds)
eventTimer.emit("reactions")
val reactionsFuture = executor.submitTyped {
db.reactionTable.getReactionsForMessages(messageIds)
}
val attachmentsResult = db.attachmentTable.getAttachmentsForMessages(messageIds)
eventTimer.emit("attachments")
val attachmentsFuture = executor.submitTyped {
db.attachmentTable.getAttachmentsForMessages(messageIds)
}
val groupReceiptsResult = db.groupReceiptTable.getGroupReceiptInfoForMessages(messageIds)
eventTimer.emit("receipts")
val groupReceiptsFuture = executor.submitTyped {
db.groupReceiptTable.getGroupReceiptInfoForMessages(messageIds)
}
val mentionsResult = mentionsFuture.get()
val reactionsResult = reactionsFuture.get()
val attachmentsResult = attachmentsFuture.get()
val groupReceiptsResult = groupReceiptsFuture.get()
return ExtraMessageData(
mentionsById = mentionsResult,
@@ -1104,13 +1109,25 @@ private fun String.e164ToLong(): Long? {
return fixed.toLongOrNull()
}
// private fun <T> ExecutorService.submitTyped(callable: Callable<T>): Future<T> {
// return this.submit(callable)
// }
/**
 * Submits [callable] to this executor and returns the [Future] for its result.
 *
 * The parameter is declared as a [Callable], so a lambda passed at the call site is converted to a
 * [Callable] and the returned future carries the lambda's result type.
 */
private fun <T> ExecutorService.submitTyped(callable: Callable<T>): Future<T> = submit(callable)
private fun Cursor.toBackupMessageRecord(pastIds: Set<Long>, backupStartTime: Long): BackupMessageRecord? {
val id = this.requireLong(MessageTable.ID)
if (pastIds.contains(id)) {
return null
}
val expiresIn = this.requireLong(MessageTable.EXPIRES_IN)
val expireStarted = this.requireLong(MessageTable.EXPIRE_STARTED)
if (expireStarted != 0L && expireStarted + expiresIn < backupStartTime + 1.days.inWholeMilliseconds) {
return null
}
private fun Cursor.toBackupMessageRecord(): BackupMessageRecord {
return BackupMessageRecord(
id = this.requireLong(MessageTable.ID),
id = id,
dateSent = this.requireLong(MessageTable.DATE_SENT),
dateReceived = this.requireLong(MessageTable.DATE_RECEIVED),
dateServer = this.requireLong(MessageTable.DATE_SERVER),
@@ -1120,8 +1137,8 @@ private fun Cursor.toBackupMessageRecord(): BackupMessageRecord {
bodyRanges = this.requireBlob(MessageTable.MESSAGE_RANGES),
fromRecipientId = this.requireLong(MessageTable.FROM_RECIPIENT_ID),
toRecipientId = this.requireLong(MessageTable.TO_RECIPIENT_ID),
expiresIn = this.requireLong(MessageTable.EXPIRES_IN),
expireStarted = this.requireLong(MessageTable.EXPIRE_STARTED),
expiresIn = expiresIn,
expireStarted = expireStarted,
remoteDeleted = this.requireBoolean(MessageTable.REMOTE_DELETED),
sealedSender = this.requireBoolean(MessageTable.UNIDENTIFIED),
linkPreview = this.requireString(MessageTable.LINK_PREVIEWS),

View File

@@ -434,7 +434,7 @@ class AttachmentTable(
return emptyMap()
}
val query = SqlUtil.buildSingleCollectionQuery(MESSAGE_ID, mmsIds)
val query = SqlUtil.buildFastCollectionQuery(MESSAGE_ID, mmsIds)
return readableDatabase
.select(*PROJECTION)

View File

@@ -140,9 +140,7 @@ class GroupReceiptTable(context: Context?, databaseHelper: SignalDatabase?) : Da
val messageIdsToGroupReceipts: MutableMap<Long, MutableList<GroupReceiptInfo>> = mutableMapOf()
val args: List<Array<String>> = ids.map { SqlUtil.buildArgs(it) }
SqlUtil.buildCustomCollectionQuery("$MMS_ID = ?", args).forEach { query ->
val query = SqlUtil.buildFastCollectionQuery(MMS_ID, ids)
readableDatabase
.select()
.from(TABLE_NAME)
@@ -153,7 +151,6 @@ class GroupReceiptTable(context: Context?, databaseHelper: SignalDatabase?) : Da
val receipts = messageIdsToGroupReceipts.getOrPut(messageId) { mutableListOf() }
receipts += cursor.toGroupReceiptInfo()
}
}
return messageIdsToGroupReceipts
}

View File

@@ -2,6 +2,7 @@ package org.thoughtcrime.securesms.database
import android.content.Context
import android.database.Cursor
import org.signal.core.util.SqlUtil
import org.signal.core.util.delete
import org.signal.core.util.deleteAll
import org.signal.core.util.insertInto
@@ -77,12 +78,16 @@ class MentionTable(context: Context, databaseHelper: SignalDatabase) : DatabaseT
}
fun getMentionsForMessages(messageIds: Collection<Long>): Map<Long, List<Mention>> {
val ids = messageIds.joinToString(separator = ",") { it.toString() }
if (messageIds.isEmpty()) {
return emptyMap()
}
val query = SqlUtil.buildFastCollectionQuery(MESSAGE_ID, messageIds)
return readableDatabase
.select()
.from("$TABLE_NAME INDEXED BY $MESSAGE_ID_INDEX")
.where("$MESSAGE_ID IN ($ids)")
.where(query.where, query.whereArgs)
.run()
.use { cursor -> readMentions(cursor) }
}

View File

@@ -6,6 +6,8 @@ import android.database.Cursor
import org.signal.core.util.CursorUtil
import org.signal.core.util.SqlUtil
import org.signal.core.util.delete
import org.signal.core.util.forEach
import org.signal.core.util.select
import org.signal.core.util.update
import org.thoughtcrime.securesms.database.model.MessageId
import org.thoughtcrime.securesms.database.model.ReactionRecord
@@ -77,11 +79,13 @@ class ReactionTable(context: Context, databaseHelper: SignalDatabase) : Database
val messageIdToReactions: MutableMap<Long, MutableList<ReactionRecord>> = mutableMapOf()
val args: List<Array<String>> = messageIds.map { SqlUtil.buildArgs(it) }
for (query: SqlUtil.Query in SqlUtil.buildCustomCollectionQuery("$MESSAGE_ID = ?", args)) {
readableDatabase.query(TABLE_NAME, null, query.where, query.whereArgs, null, null, null).use { cursor ->
while (cursor.moveToNext()) {
val query = SqlUtil.buildFastCollectionQuery(MESSAGE_ID, messageIds)
readableDatabase
.select()
.from(TABLE_NAME)
.where(query.where, query.whereArgs)
.run()
.forEach { cursor ->
val reaction: ReactionRecord = readReaction(cursor)
val messageId = CursorUtil.requireLong(cursor, MESSAGE_ID)
@@ -94,8 +98,6 @@ class ReactionTable(context: Context, databaseHelper: SignalDatabase) : Database
reactionsList.add(reaction)
}
}
}
return messageIdToReactions
}

View File

@@ -154,6 +154,17 @@ object SqlUtil {
}.toTypedArray()
}
/**
 * Converts every element of [objects] into its string form for use as a query argument.
 *
 * A [DatabaseId] contributes its `serialize()` value; anything else contributes `toString()`.
 *
 * @throws NullPointerException if any element is null.
 */
@JvmStatic
fun buildArgs(objects: Collection<Any?>): Array<String> {
  val args = ArrayList<String>(objects.size)

  for (arg in objects) {
    if (arg == null) {
      throw NullPointerException("Cannot have null arg!")
    }

    args += if (arg is DatabaseId) arg.serialize() else arg.toString()
  }

  return args.toTypedArray()
}
@JvmStatic
fun buildArgs(argument: Long): Array<String> {
return arrayOf(argument.toString())
@@ -290,6 +301,20 @@ object SqlUtil {
}
}
/**
 * Builds a [Query] that is _equivalent_ to `WHERE [column] IN (?, ?, ..., ?)`.
 *
 * Rather than one placeholder per value, the whole collection is bound as a single JSON array argument and
 * expanded with the JSON1 `json_each` function. This can be surprisingly faster than a normal (?, ?, ?) list,
 * and it removes the [MAX_QUERY_ARGS] limit, so chunking isn't necessary for any practical collection length.
 *
 * @param column the column to match against.
 * @param values the values to match; must not be empty.
 * @throws IllegalArgumentException if [values] is empty.
 */
@JvmStatic
fun buildFastCollectionQuery(
  column: String,
  values: Collection<Any?>
): Query {
  require(values.isNotEmpty()) { "Must have values!" }

  val encodedValues: String = jsonEncode(buildArgs(values))
  val whereClause = "$column IN (SELECT e.value FROM json_each(?) e)"

  return Query(whereClause, arrayOf(encodedValues))
}
/**
* A convenient way of making queries in the form: WHERE [column] IN (?, ?, ..., ?)
*
@@ -453,6 +478,11 @@ object SqlUtil {
return null
}
/**
 * Encodes a string array as a JSON array of string literals, e.g. `["1","2"]`.
 *
 * Backslashes, double quotes and control characters are escaped, so a value containing them still yields
 * well-formed JSON and cannot terminate its literal early to smuggle in extra array elements.
 * Values without such characters are encoded exactly as they were before escaping was added.
 */
private fun jsonEncode(strings: Array<String>): String {
  return strings.joinToString(prefix = "[", postfix = "]", separator = ",") { "\"${jsonEscape(it)}\"" }
}

/** Escapes the characters that may not appear unescaped inside a JSON string literal. */
private fun jsonEscape(value: String): String {
  return buildString(value.length) {
    for (c in value) {
      when {
        c == '\\' -> append("\\\\")
        c == '"' -> append("\\\"")
        c == '\n' -> append("\\n")
        c == '\r' -> append("\\r")
        c == '\t' -> append("\\t")
        // Remaining control characters have no short escape and must use the \uXXXX form.
        c < ' ' -> append("\\u").append(c.code.toString(16).padStart(4, '0'))
        else -> append(c)
      }
    }
  }
}
class Query(val where: String, val whereArgs: Array<String>) {
infix fun and(other: Query): Query {
return if (where.isNotEmpty() && other.where.isNotEmpty()) {

View File

@@ -170,6 +170,22 @@ public final class SqlUtilTest {
assertTrue(results.isEmpty());
}
/** A single value is encoded as a one-element JSON array bound to the single placeholder. */
@Test
public void buildFastCollectionQuery_single() {
  String   expectedWhere = "a IN (SELECT e.value FROM json_each(?) e)";
  String[] expectedArgs  = new String[] { "[\"1\"]" };

  SqlUtil.Query query = SqlUtil.buildFastCollectionQuery("a", Arrays.asList(1));

  assertEquals(expectedWhere, query.getWhere());
  assertArrayEquals(expectedArgs, query.getWhereArgs());
}
/** Several values still produce one placeholder, with all of them encoded into a single JSON array argument. */
@Test
public void buildFastCollectionQuery_multiple() {
  String   expectedWhere = "a IN (SELECT e.value FROM json_each(?) e)";
  String[] expectedArgs  = new String[] { "[\"1\",\"2\",\"3\"]" };

  SqlUtil.Query query = SqlUtil.buildFastCollectionQuery("a", Arrays.asList(1, 2, 3));

  assertEquals(expectedWhere, query.getWhere());
  assertArrayEquals(expectedArgs, query.getWhereArgs());
}
@Test
public void buildCustomCollectionQuery_single_singleBatch() {
List<String[]> args = new ArrayList<>();