Add additional CDN reconciliations to BackupMediaSnapshotSyncJob.

Co-authored-by: Cody Henthorne <cody@signal.org>
This commit is contained in:
Greyson Parrelli
2025-04-25 11:03:26 -04:00
committed by Cody Henthorne
parent 85647f1258
commit f73d929feb
13 changed files with 434 additions and 45 deletions

View File

@@ -238,6 +238,23 @@ class AttachmentTableTest {
assertArrayEquals(digest, newDigest) assertArrayEquals(digest, newDigest)
} }
@Test
fun resetArchiveTransferStateByDigest_singleMatch() {
// Given an attachment with some digest
val blob = BlobProvider.getInstance().forData(byteArrayOf(1, 2, 3, 4, 5)).createForSingleSessionInMemory()
val attachment = createAttachment(1, blob, AttachmentTable.TransformProperties.empty())
val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(-1L, listOf(attachment), emptyList()).values.first()
SignalDatabase.attachments.finalizeAttachmentAfterUpload(attachmentId, AttachmentTableTestUtil.createUploadResult(attachmentId))
SignalDatabase.attachments.setArchiveTransferState(attachmentId, AttachmentTable.ArchiveTransferState.FINISHED)
// Reset the transfer state by digest
val digest = SignalDatabase.attachments.getAttachment(attachmentId)!!.remoteDigest!!
SignalDatabase.attachments.resetArchiveTransferStateByDigest(digest)
// Verify it's been reset
assertThat(SignalDatabase.attachments.getAttachment(attachmentId)!!.archiveTransferState).isEqualTo(AttachmentTable.ArchiveTransferState.NONE)
}
private fun createAttachmentPointer(key: ByteArray, digest: ByteArray, size: Int): Attachment { private fun createAttachmentPointer(key: ByteArray, digest: ByteArray, size: Int): Attachment {
return PointerAttachment.forPointer( return PointerAttachment.forPointer(
pointer = Optional.of( pointer = Optional.of(

View File

@@ -0,0 +1,34 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.database
import org.signal.core.util.Base64
import org.thoughtcrime.securesms.attachments.AttachmentId
import org.thoughtcrime.securesms.attachments.Cdn
import org.thoughtcrime.securesms.util.Util
import org.whispersystems.signalservice.api.attachment.AttachmentUploadResult
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId
import kotlin.random.Random
object AttachmentTableTestUtil {
fun createUploadResult(attachmentId: AttachmentId, uploadTimestamp: Long = System.currentTimeMillis()): AttachmentUploadResult {
val databaseAttachment = SignalDatabase.attachments.getAttachment(attachmentId)!!
return AttachmentUploadResult(
remoteId = SignalServiceAttachmentRemoteId.V4("somewhere-${Random.nextLong()}"),
cdnNumber = Cdn.CDN_3.cdnNumber,
key = databaseAttachment.remoteKey?.let { Base64.decode(it) } ?: Util.getSecretBytes(64),
iv = databaseAttachment.remoteIv ?: Util.getSecretBytes(16),
digest = Random.nextBytes(32),
incrementalDigest = Random.nextBytes(16),
incrementalDigestChunkSize = 5,
uploadTimestamp = uploadTimestamp,
dataSize = databaseAttachment.size,
blurHash = databaseAttachment.blurHash?.hash
)
}
}

View File

@@ -28,8 +28,6 @@ import org.thoughtcrime.securesms.providers.BlobProvider
import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.util.MediaUtil import org.thoughtcrime.securesms.util.MediaUtil
import org.thoughtcrime.securesms.util.Util import org.thoughtcrime.securesms.util.Util
import org.whispersystems.signalservice.api.attachment.AttachmentUploadResult
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId
import org.whispersystems.signalservice.api.push.ServiceId import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream import org.whispersystems.signalservice.internal.crypto.PaddingInputStream
import java.io.File import java.io.File
@@ -729,7 +727,7 @@ class AttachmentTableTest_deduping {
fun upload(attachmentId: AttachmentId, uploadTimestamp: Long = System.currentTimeMillis()) { fun upload(attachmentId: AttachmentId, uploadTimestamp: Long = System.currentTimeMillis()) {
SignalDatabase.attachments.createKeyIvIfNecessary(attachmentId) SignalDatabase.attachments.createKeyIvIfNecessary(attachmentId)
SignalDatabase.attachments.finalizeAttachmentAfterUpload(attachmentId, createUploadResult(attachmentId, uploadTimestamp)) SignalDatabase.attachments.finalizeAttachmentAfterUpload(attachmentId, AttachmentTableTestUtil.createUploadResult(attachmentId, uploadTimestamp))
val attachment = SignalDatabase.attachments.getAttachment(attachmentId)!! val attachment = SignalDatabase.attachments.getAttachment(attachmentId)!!
SignalDatabase.attachments.setArchiveCdn( SignalDatabase.attachments.setArchiveCdn(
@@ -875,23 +873,6 @@ class AttachmentTableTest_deduping {
private fun ByteArray.asMediaStream(): MediaStream { private fun ByteArray.asMediaStream(): MediaStream {
return MediaStream(this.inputStream(), MediaUtil.IMAGE_JPEG, 2, 2) return MediaStream(this.inputStream(), MediaUtil.IMAGE_JPEG, 2, 2)
} }
private fun createUploadResult(attachmentId: AttachmentId, uploadTimestamp: Long = System.currentTimeMillis()): AttachmentUploadResult {
val databaseAttachment = SignalDatabase.attachments.getAttachment(attachmentId)!!
return AttachmentUploadResult(
remoteId = SignalServiceAttachmentRemoteId.V4("somewhere-${Random.nextLong()}"),
cdnNumber = Cdn.CDN_3.cdnNumber,
key = databaseAttachment.remoteKey?.let { Base64.decode(it) } ?: Util.getSecretBytes(64),
iv = databaseAttachment.remoteIv ?: Util.getSecretBytes(16),
digest = Random.nextBytes(32),
incrementalDigest = Random.nextBytes(16),
incrementalDigestChunkSize = 5,
uploadTimestamp = uploadTimestamp,
dataSize = databaseAttachment.size,
blurHash = databaseAttachment.blurHash?.hash
)
}
} }
private fun test(content: TestContext.() -> Unit) { private fun test(content: TestContext.() -> Unit) {

View File

@@ -663,6 +663,20 @@ class AttachmentTable(
} }
} }
/**
* Sets the archive transfer state for the given attachment by digest.
*/
fun resetArchiveTransferStateByDigest(digest: ByteArray) {
writableDatabase
.update(TABLE_NAME)
.values(
ARCHIVE_TRANSFER_STATE to ArchiveTransferState.NONE.value,
ARCHIVE_CDN to 0
)
.where("$REMOTE_DIGEST = ?", digest)
.run()
}
/** /**
* Sets the archive transfer state for the given attachment and all other attachments that share the same data file. * Sets the archive transfer state for the given attachment and all other attachments that share the same data file.
*/ */

View File

@@ -6,17 +6,21 @@
package org.thoughtcrime.securesms.database package org.thoughtcrime.securesms.database
import android.content.Context import android.content.Context
import android.database.Cursor
import androidx.annotation.VisibleForTesting import androidx.annotation.VisibleForTesting
import androidx.core.content.contentValuesOf import androidx.core.content.contentValuesOf
import org.signal.core.util.SqlUtil import org.signal.core.util.SqlUtil
import org.signal.core.util.delete import org.signal.core.util.delete
import org.signal.core.util.readToList import org.signal.core.util.readToList
import org.signal.core.util.readToSet import org.signal.core.util.readToSet
import org.signal.core.util.requireBoolean
import org.signal.core.util.requireInt import org.signal.core.util.requireInt
import org.signal.core.util.requireNonNullBlob import org.signal.core.util.requireNonNullBlob
import org.signal.core.util.requireNonNullString import org.signal.core.util.requireNonNullString
import org.signal.core.util.select import org.signal.core.util.select
import org.signal.core.util.toInt import org.signal.core.util.toInt
import org.signal.core.util.update
import org.signal.core.util.updateAll
import org.thoughtcrime.securesms.backup.v2.ArchivedMediaObject import org.thoughtcrime.securesms.backup.v2.ArchivedMediaObject
/** /**
@@ -60,6 +64,11 @@ class BackupMediaSnapshotTable(context: Context, database: SignalDatabase) : Dat
*/ */
const val IS_THUMBNAIL = "is_thumbnail" const val IS_THUMBNAIL = "is_thumbnail"
/**
* Timestamp when media was last seen on archive cdn. Can be reset to default.
*/
const val LAST_SEEN_ON_REMOTE_TIMESTAMP = "last_seen_on_remote_timestamp"
/** /**
* The remote digest for the media object. This is used to find matching attachments in the attachment table when necessary. * The remote digest for the media object. This is used to find matching attachments in the attachment table when necessary.
*/ */
@@ -73,7 +82,8 @@ class BackupMediaSnapshotTable(context: Context, database: SignalDatabase) : Dat
$LAST_SYNC_TIME INTEGER DEFAULT 0, $LAST_SYNC_TIME INTEGER DEFAULT 0,
$PENDING_SYNC_TIME INTEGER, $PENDING_SYNC_TIME INTEGER,
$IS_THUMBNAIL INTEGER DEFAULT 0, $IS_THUMBNAIL INTEGER DEFAULT 0,
$REMOTE_DIGEST BLOB NOT NULL $REMOTE_DIGEST BLOB NOT NULL,
$LAST_SEEN_ON_REMOTE_TIMESTAMP INTEGER DEFAULT 0
) )
""".trimIndent() """.trimIndent()
} }
@@ -132,14 +142,17 @@ class BackupMediaSnapshotTable(context: Context, database: SignalDatabase) : Dat
return emptySet() return emptySet()
} }
val query = SqlUtil.buildSingleCollectionQuery( val queries: List<SqlUtil.Query> = SqlUtil.buildCollectionQuery(
column = MEDIA_ID, column = MEDIA_ID,
values = objects.map { it.mediaId }, values = objects.map { it.mediaId },
collectionOperator = SqlUtil.CollectionOperator.NOT_IN, collectionOperator = SqlUtil.CollectionOperator.NOT_IN,
prefix = "$IS_THUMBNAIL = 0 AND " prefix = "$IS_THUMBNAIL = 0 AND "
) )
return readableDatabase val out: MutableSet<ArchivedMediaObject> = mutableSetOf()
for (query in queries) {
out += readableDatabase
.select(MEDIA_ID, CDN) .select(MEDIA_ID, CDN)
.from(TABLE_NAME) .from(TABLE_NAME)
.where(query.where, query.whereArgs) .where(query.where, query.whereArgs)
@@ -152,6 +165,9 @@ class BackupMediaSnapshotTable(context: Context, database: SignalDatabase) : Dat
} }
} }
return out
}
/** /**
* Given a list of media objects, find the ones that we have no knowledge of in our local store. * Given a list of media objects, find the ones that we have no knowledge of in our local store.
*/ */
@@ -177,6 +193,47 @@ class BackupMediaSnapshotTable(context: Context, database: SignalDatabase) : Dat
} }
} }
/**
* Indicate the time that the set of media objects were seen on the archive CDN. Can be used to reconcile our local state with the server state.
*/
fun markSeenOnRemote(mediaIdBatch: Collection<String>, time: Long) {
if (mediaIdBatch.isEmpty()) {
return
}
val query = SqlUtil.buildFastCollectionQuery(MEDIA_ID, mediaIdBatch)
writableDatabase
.update(TABLE_NAME)
.values(LAST_SEEN_ON_REMOTE_TIMESTAMP to time)
.where(query.where, query.whereArgs)
.run()
}
/**
* Get all media objects who were last seen on the remote server before the given time.
* This is used to find media objects that have not been seen on the CDN, even though they should be.
*
* The cursor contains rows that can be parsed into [MediaEntry] objects.
*/
fun getMediaObjectsLastSeenOnCdnBeforeTime(time: Long): Cursor {
return readableDatabase
.select(MEDIA_ID, CDN, REMOTE_DIGEST, IS_THUMBNAIL)
.from(TABLE_NAME)
.where("$LAST_SEEN_ON_REMOTE_TIMESTAMP < $time")
.run()
}
/**
* Resets the [LAST_SEEN_ON_REMOTE_TIMESTAMP] column back to zero. It's a good idea to do this after you have run a sync and used the value, as it can
* mitigate various issues that can arise from having an incorrect local clock.
*/
fun clearLastSeenOnRemote() {
writableDatabase
.updateAll(TABLE_NAME)
.values(LAST_SEEN_ON_REMOTE_TIMESTAMP to 0)
.run()
}
private fun writePendingMediaObjectsChunk(chunk: List<MediaEntry>, pendingSyncTime: Long) { private fun writePendingMediaObjectsChunk(chunk: List<MediaEntry>, pendingSyncTime: Long) {
val values = chunk.map { val values = chunk.map {
contentValuesOf( contentValuesOf(
@@ -213,10 +270,21 @@ class BackupMediaSnapshotTable(context: Context, database: SignalDatabase) : Dat
val cdn: Int val cdn: Int
) )
private data class MediaEntry( class MediaEntry(
val mediaId: String, val mediaId: String,
val cdn: Int, val cdn: Int,
val digest: ByteArray, val digest: ByteArray,
val isThumbnail: Boolean val isThumbnail: Boolean
) {
companion object {
fun fromCursor(cursor: Cursor): MediaEntry {
return MediaEntry(
mediaId = cursor.requireNonNullString(MEDIA_ID),
cdn = cursor.requireInt(CDN),
digest = cursor.requireNonNullBlob(REMOTE_DIGEST),
isThumbnail = cursor.requireBoolean(IS_THUMBNAIL)
) )
}
}
}
} }

View File

@@ -128,6 +128,7 @@ import org.thoughtcrime.securesms.database.helpers.migration.V270_FixChatFolderC
import org.thoughtcrime.securesms.database.helpers.migration.V271_AddNotificationProfileIdColumn import org.thoughtcrime.securesms.database.helpers.migration.V271_AddNotificationProfileIdColumn
import org.thoughtcrime.securesms.database.helpers.migration.V272_UpdateUnreadCountIndices import org.thoughtcrime.securesms.database.helpers.migration.V272_UpdateUnreadCountIndices
import org.thoughtcrime.securesms.database.helpers.migration.V273_FixUnreadOriginalMessages import org.thoughtcrime.securesms.database.helpers.migration.V273_FixUnreadOriginalMessages
import org.thoughtcrime.securesms.database.helpers.migration.V274_BackupMediaSnapshotLastSeenOnRemote
import org.thoughtcrime.securesms.database.SQLiteDatabase as SignalSqliteDatabase import org.thoughtcrime.securesms.database.SQLiteDatabase as SignalSqliteDatabase
/** /**
@@ -261,10 +262,11 @@ object SignalDatabaseMigrations {
270 to V270_FixChatFolderColumnsForStorageSync, 270 to V270_FixChatFolderColumnsForStorageSync,
271 to V271_AddNotificationProfileIdColumn, 271 to V271_AddNotificationProfileIdColumn,
272 to V272_UpdateUnreadCountIndices, 272 to V272_UpdateUnreadCountIndices,
273 to V273_FixUnreadOriginalMessages 273 to V273_FixUnreadOriginalMessages,
274 to V274_BackupMediaSnapshotLastSeenOnRemote
) )
const val DATABASE_VERSION = 273 const val DATABASE_VERSION = 274
@JvmStatic @JvmStatic
fun migrate(context: Application, db: SignalSqliteDatabase, oldVersion: Int, newVersion: Int) { fun migrate(context: Application, db: SignalSqliteDatabase, oldVersion: Int, newVersion: Int) {

View File

@@ -0,0 +1,18 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.database.helpers.migration
import android.app.Application
import org.thoughtcrime.securesms.database.SQLiteDatabase
/**
* Added a column to the backup media snapshot table to keep track of the last time we saw an object on the CDN.
*/
object V274_BackupMediaSnapshotLastSeenOnRemote : SignalDatabaseMigration {
override fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
db.execSQL("ALTER TABLE backup_media_snapshot ADD COLUMN last_seen_on_remote_timestamp INTEGER DEFAULT 0")
}
}

View File

@@ -14,8 +14,7 @@ import org.thoughtcrime.securesms.keyvalue.SignalStore
import kotlin.time.Duration.Companion.days import kotlin.time.Duration.Companion.days
/** /**
* When run, this will find the next attachment that needs to be uploaded to the archive service and upload it. * When run, this will find all of the attachments that need to be uploaded to the archive tier and enqueue [UploadAttachmentToArchiveJob]s for them.
* It will enqueue a copy of itself if it thinks there is more work to be done, and that copy will continue the upload process.
*/ */
class ArchiveAttachmentBackfillJob private constructor(parameters: Parameters) : Job(parameters) { class ArchiveAttachmentBackfillJob private constructor(parameters: Parameters) : Job(parameters) {
companion object { companion object {

View File

@@ -5,10 +5,12 @@
package org.thoughtcrime.securesms.jobs package org.thoughtcrime.securesms.jobs
import org.signal.core.util.forEach
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import org.signal.core.util.nullIfBlank import org.signal.core.util.nullIfBlank
import org.thoughtcrime.securesms.backup.v2.ArchivedMediaObject import org.thoughtcrime.securesms.backup.v2.ArchivedMediaObject
import org.thoughtcrime.securesms.backup.v2.BackupRepository import org.thoughtcrime.securesms.backup.v2.BackupRepository
import org.thoughtcrime.securesms.database.BackupMediaSnapshotTable
import org.thoughtcrime.securesms.database.SignalDatabase import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.dependencies.AppDependencies import org.thoughtcrime.securesms.dependencies.AppDependencies
import org.thoughtcrime.securesms.jobmanager.Job import org.thoughtcrime.securesms.jobmanager.Job
@@ -40,7 +42,8 @@ class BackupMediaSnapshotSyncJob private constructor(
const val KEY = "BackupMediaSnapshotSyncJob" const val KEY = "BackupMediaSnapshotSyncJob"
private const val REMOTE_DELETE_BATCH_SIZE = 500 private const val REMOTE_DELETE_BATCH_SIZE = 750
private const val CDN_PAGE_SIZE = 10_000
private val BACKUP_MEDIA_SYNC_INTERVAL = 7.days.inWholeMilliseconds private val BACKUP_MEDIA_SYNC_INTERVAL = 7.days.inWholeMilliseconds
fun enqueue(syncTime: Long) { fun enqueue(syncTime: Long) {
@@ -79,7 +82,9 @@ class BackupMediaSnapshotSyncJob private constructor(
return syncDataFromCdn() ?: Result.success() return syncDataFromCdn() ?: Result.success()
} }
override fun onFailure() = Unit override fun onFailure() {
SignalDatabase.backupMediaSnapshots.clearLastSeenOnRemote()
}
/** /**
* Looks through our local snapshot of what attachments we put in the last backup file, and uses that to delete any old attachments from the archive CDN * Looks through our local snapshot of what attachments we put in the last backup file, and uses that to delete any old attachments from the archive CDN
@@ -92,7 +97,7 @@ class BackupMediaSnapshotSyncJob private constructor(
deleteMediaObjectsFromCdn(mediaObjects)?.let { result -> return result } deleteMediaObjectsFromCdn(mediaObjects)?.let { result -> return result }
SignalDatabase.backupMediaSnapshots.deleteMediaObjects(mediaObjects) SignalDatabase.backupMediaSnapshots.deleteMediaObjects(mediaObjects)
mediaObjects = SignalDatabase.backupMediaSnapshots.getPageOfOldMediaObjects(syncTime, REMOTE_DELETE_BATCH_SIZE) mediaObjects = SignalDatabase.backupMediaSnapshots.getPageOfOldMediaObjects(syncTime, CDN_PAGE_SIZE)
} }
return null return null
@@ -135,6 +140,25 @@ class BackupMediaSnapshotSyncJob private constructor(
deleteMediaObjectsFromCdn(attachmentsToDelete)?.let { result -> return result } deleteMediaObjectsFromCdn(attachmentsToDelete)?.let { result -> return result }
} }
val entriesNeedingRepairCursor = SignalDatabase.backupMediaSnapshots.getMediaObjectsLastSeenOnCdnBeforeTime(syncTime)
val needRepairCount = entriesNeedingRepairCursor.count
if (needRepairCount > 0) {
Log.w(TAG, "Found $needRepairCount attachments that we thought were uploaded, but could not be found on the CDN. Clearing state and enqueuing uploads.")
entriesNeedingRepairCursor.forEach {
val entry = BackupMediaSnapshotTable.MediaEntry.fromCursor(it)
// TODO [backup] Re-enqueue thumbnail uploads if necessary
if (!entry.isThumbnail) {
SignalDatabase.attachments.resetArchiveTransferStateByDigest(entry.digest)
}
}
BackupMessagesJob.enqueue()
} else {
Log.d(TAG, "No attachments need to be repaired.")
}
SignalStore.backup.lastMediaSyncTime = System.currentTimeMillis() SignalStore.backup.lastMediaSyncTime = System.currentTimeMillis()
return null return null
@@ -152,6 +176,11 @@ class BackupMediaSnapshotSyncJob private constructor(
) )
} }
SignalDatabase.backupMediaSnapshots.markSeenOnRemote(
mediaIdBatch = mediaObjects.map { it.mediaId },
time = syncTime
)
val notFoundMediaObjects = SignalDatabase.backupMediaSnapshots.getMediaObjectsThatCantBeFound(mediaObjects) val notFoundMediaObjects = SignalDatabase.backupMediaSnapshots.getMediaObjectsThatCantBeFound(mediaObjects)
val remainingObjects = mediaObjects - notFoundMediaObjects val remainingObjects = mediaObjects - notFoundMediaObjects

View File

@@ -12,8 +12,10 @@ dependencies {
implementation(libs.androidx.sqlite) implementation(libs.androidx.sqlite)
implementation(libs.androidx.documentfile) implementation(libs.androidx.documentfile)
testImplementation(libs.androidx.sqlite.framework)
testImplementation(testLibs.junit.junit) testImplementation(testLibs.junit.junit)
testImplementation(testLibs.assertk)
testImplementation(testLibs.robolectric.robolectric) testImplementation(testLibs.robolectric.robolectric)
} }

View File

@@ -6,6 +6,7 @@ import android.database.sqlite.SQLiteDatabase
import androidx.core.content.contentValuesOf import androidx.core.content.contentValuesOf
import androidx.sqlite.db.SupportSQLiteDatabase import androidx.sqlite.db.SupportSQLiteDatabase
import androidx.sqlite.db.SupportSQLiteQueryBuilder import androidx.sqlite.db.SupportSQLiteQueryBuilder
import androidx.sqlite.db.SupportSQLiteStatement
import org.signal.core.util.SqlUtil.ForeignKeyViolation import org.signal.core.util.SqlUtil.ForeignKeyViolation
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import kotlin.time.Duration import kotlin.time.Duration
@@ -246,10 +247,34 @@ fun SupportSQLiteDatabase.deleteAll(tableName: String): Int {
return this.delete(tableName, null, arrayOfNulls<String>(0)) return this.delete(tableName, null, arrayOfNulls<String>(0))
} }
/**
* Begins an INSERT statement with a helpful builder pattern.
*/
fun SupportSQLiteDatabase.insertInto(tableName: String): InsertBuilderPart1 { fun SupportSQLiteDatabase.insertInto(tableName: String): InsertBuilderPart1 {
return InsertBuilderPart1(this, tableName) return InsertBuilderPart1(this, tableName)
} }
/**
* Bind an arbitrary value to an index. It will handle calling the correct bind method based on the class type.
* @param index The index you want to bind to. Important: Indexes start at 1, not 0.
*/
fun SupportSQLiteStatement.bindValue(index: Int, value: Any?) {
when (value) {
null -> this.bindNull(index)
is DatabaseId -> this.bindString(index, value.serialize())
is Boolean -> this.bindLong(index, value.toInt().toLong())
is ByteArray -> this.bindBlob(index, value)
is Number -> {
if (value.toLong() == value) {
this.bindLong(index, value.toLong())
} else {
this.bindDouble(index, value.toDouble())
}
}
else -> this.bindString(index, value.toString())
}
}
class SelectBuilderPart1( class SelectBuilderPart1(
private val db: SupportSQLiteDatabase, private val db: SupportSQLiteDatabase,
private val columns: Array<String> private val columns: Array<String>
@@ -422,7 +447,7 @@ class UpdateBuilderPart2(
) { ) {
fun where(where: String, vararg whereArgs: Any): UpdateBuilderPart3 { fun where(where: String, vararg whereArgs: Any): UpdateBuilderPart3 {
require(where.isNotBlank()) require(where.isNotBlank())
return UpdateBuilderPart3(db, tableName, values, where, SqlUtil.buildArgs(*whereArgs)) return UpdateBuilderPart3(db, tableName, values, where, whereArgs.toArgs())
} }
fun where(where: String, whereArgs: Array<String>): UpdateBuilderPart3 { fun where(where: String, whereArgs: Array<String>): UpdateBuilderPart3 {
@@ -436,11 +461,45 @@ class UpdateBuilderPart3(
private val tableName: String, private val tableName: String,
private val values: ContentValues, private val values: ContentValues,
private val where: String, private val where: String,
private val whereArgs: Array<String> private val whereArgs: Array<out Any?>
) { ) {
@JvmOverloads @JvmOverloads
fun run(conflictStrategy: Int = SQLiteDatabase.CONFLICT_NONE): Int { fun run(conflictStrategy: Int = SQLiteDatabase.CONFLICT_NONE): Int {
return db.update(tableName, conflictStrategy, values, where, whereArgs) val query = StringBuilder("UPDATE $tableName SET ")
val contentValuesKeys = values.keySet()
for ((index, column) in contentValuesKeys.withIndex()) {
query.append(column).append(" = ?")
if (index < contentValuesKeys.size - 1) {
query.append(", ")
}
}
query.append(" WHERE ").append(where)
val conflictString = when (conflictStrategy) {
SQLiteDatabase.CONFLICT_IGNORE -> " ON CONFLICT IGNORE"
SQLiteDatabase.CONFLICT_ABORT -> " ON CONFLICT ABORT"
SQLiteDatabase.CONFLICT_FAIL -> " ON CONFLICT FAIL"
SQLiteDatabase.CONFLICT_ROLLBACK -> " ON CONFLICT ROLLBACK"
SQLiteDatabase.CONFLICT_REPLACE -> " ON CONFLICT REPLACE"
else -> ""
}
query.append(conflictString)
val statement = db.compileStatement(query.toString())
var bindIndex = 1
for (key in contentValuesKeys) {
statement.bindValue(bindIndex, values.get(key))
bindIndex++
}
for (arg in whereArgs) {
statement.bindValue(bindIndex, arg)
bindIndex++
}
return statement.use { it.executeUpdateDelete() }
} }
} }
@@ -550,6 +609,20 @@ class InsertBuilderPart2(
} }
} }
/**
* Helper function to massage passed-in arguments into a better form to give to the database.
*/
private fun Array<out Any?>.toArgs(): Array<Any?> {
return this
.map {
when (it) {
is DatabaseId -> it.serialize()
else -> it
}
}
.toTypedArray()
}
data class ForeignKeyConstraint( data class ForeignKeyConstraint(
val table: String, val table: String,
val column: String, val column: String,

View File

@@ -0,0 +1,34 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import androidx.sqlite.db.SupportSQLiteDatabase
import androidx.sqlite.db.SupportSQLiteOpenHelper
import androidx.sqlite.db.framework.FrameworkSQLiteOpenHelperFactory
import androidx.test.core.app.ApplicationProvider
/**
* Helper to create an in-memory database used for testing SQLite stuff.
*/
object InMemorySqliteOpenHelper {
fun create(
onCreate: (db: SupportSQLiteDatabase) -> Unit,
onUpgrade: (db: SupportSQLiteDatabase, oldVersion: Int, newVersion: Int) -> Unit = { _, _, _ -> }
): SupportSQLiteOpenHelper {
val configuration = SupportSQLiteOpenHelper.Configuration(
context = ApplicationProvider.getApplicationContext(),
name = "test",
callback = object : SupportSQLiteOpenHelper.Callback(1) {
override fun onCreate(db: SupportSQLiteDatabase) = onCreate(db)
override fun onUpgrade(db: SupportSQLiteDatabase, oldVersion: Int, newVersion: Int) = onUpgrade(db, oldVersion, newVersion)
},
useNoBackupDirectory = false,
allowDataLossOnRecovery = true
)
return FrameworkSQLiteOpenHelperFactory().create(configuration)
}
}

View File

@@ -0,0 +1,118 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import android.app.Application
import androidx.sqlite.db.SupportSQLiteOpenHelper
import assertk.assertThat
import assertk.assertions.isEqualTo
import assertk.assertions.isNotNull
import org.junit.Assert.assertArrayEquals
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
import org.robolectric.annotation.Config
@RunWith(RobolectricTestRunner::class)
@Config(manifest = Config.NONE, application = Application::class)
class SQLiteDatabaseExtensionsTest {
lateinit var db: SupportSQLiteOpenHelper
companion object {
const val TABLE_NAME = "test"
const val ID = "_id"
const val STRING_COLUMN = "string_column"
const val LONG_COLUMN = "long_column"
const val DOUBLE_COLUMN = "double_column"
const val BLOB_COLUMN = "blob_column"
}
@Before
fun setup() {
db = InMemorySqliteOpenHelper.create(
onCreate = { db ->
db.execSQL("CREATE TABLE $TABLE_NAME ($ID INTEGER PRIMARY KEY AUTOINCREMENT, $STRING_COLUMN TEXT, $LONG_COLUMN INTEGER, $DOUBLE_COLUMN DOUBLE, $BLOB_COLUMN BLOB)")
}
)
db.writableDatabase.insertInto(TABLE_NAME)
.values(
STRING_COLUMN to "asdf",
LONG_COLUMN to 1,
DOUBLE_COLUMN to 0.5f,
BLOB_COLUMN to byteArrayOf(1, 2, 3)
)
.run()
}
@Test
fun `update - content values work`() {
val updateCount: Int = db.writableDatabase
.update("test")
.values(
STRING_COLUMN to "asdf2",
LONG_COLUMN to 2,
DOUBLE_COLUMN to 1.5f,
BLOB_COLUMN to byteArrayOf(4, 5, 6)
)
.where("$ID = ?", 1)
.run()
val record = readRecord(1)
assertThat(updateCount).isEqualTo(1)
assertThat(record).isNotNull()
assertThat(record!!.id).isEqualTo(1)
assertThat(record.stringColumn).isEqualTo("asdf2")
assertThat(record.longColumn).isEqualTo(2)
assertThat(record.doubleColumn).isEqualTo(1.5f)
assertArrayEquals(record.blobColumn, byteArrayOf(4, 5, 6))
}
@Test
fun `update - querying by blob works`() {
val updateCount: Int = db.writableDatabase
.update("test")
.values(
STRING_COLUMN to "asdf2"
)
.where("$BLOB_COLUMN = ?", byteArrayOf(1, 2, 3))
.run()
val record = readRecord(1)
assertThat(updateCount).isEqualTo(1)
assertThat(record).isNotNull()
assertThat(record!!.stringColumn).isEqualTo("asdf2")
}
private fun readRecord(id: Long): TestRecord? {
return db.readableDatabase
.select()
.from(TABLE_NAME)
.where("$ID = ?", id)
.run()
.readToSingleObject {
TestRecord(
id = it.requireLong(ID),
stringColumn = it.requireString(STRING_COLUMN),
longColumn = it.requireLong(LONG_COLUMN),
doubleColumn = it.requireFloat(DOUBLE_COLUMN),
blobColumn = it.requireBlob(BLOB_COLUMN)
)
}
}
class TestRecord(
val id: Long,
val stringColumn: String?,
val longColumn: Long,
val doubleColumn: Float,
val blobColumn: ByteArray?
)
}