diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt index a987b00727..fbf8006e23 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt @@ -238,6 +238,23 @@ class AttachmentTableTest { assertArrayEquals(digest, newDigest) } + @Test + fun resetArchiveTransferStateByDigest_singleMatch() { + // Given an attachment with some digest + val blob = BlobProvider.getInstance().forData(byteArrayOf(1, 2, 3, 4, 5)).createForSingleSessionInMemory() + val attachment = createAttachment(1, blob, AttachmentTable.TransformProperties.empty()) + val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(-1L, listOf(attachment), emptyList()).values.first() + SignalDatabase.attachments.finalizeAttachmentAfterUpload(attachmentId, AttachmentTableTestUtil.createUploadResult(attachmentId)) + SignalDatabase.attachments.setArchiveTransferState(attachmentId, AttachmentTable.ArchiveTransferState.FINISHED) + + // Reset the transfer state by digest + val digest = SignalDatabase.attachments.getAttachment(attachmentId)!!.remoteDigest!! + SignalDatabase.attachments.resetArchiveTransferStateByDigest(digest) + + // Verify it's been reset + assertThat(SignalDatabase.attachments.getAttachment(attachmentId)!!.archiveTransferState).isEqualTo(AttachmentTable.ArchiveTransferState.NONE) + } + private fun createAttachmentPointer(key: ByteArray, digest: ByteArray, size: Int): Attachment { return PointerAttachment.forPointer( pointer = Optional.of( diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTestUtil.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTestUtil.kt new file mode 100644 index 0000000000..6b4d651eee --- /dev/null +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTestUtil.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.thoughtcrime.securesms.database + +import org.signal.core.util.Base64 +import org.thoughtcrime.securesms.attachments.AttachmentId +import org.thoughtcrime.securesms.attachments.Cdn +import org.thoughtcrime.securesms.util.Util +import org.whispersystems.signalservice.api.attachment.AttachmentUploadResult +import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId +import kotlin.random.Random + +object AttachmentTableTestUtil { + + fun createUploadResult(attachmentId: AttachmentId, uploadTimestamp: Long = System.currentTimeMillis()): AttachmentUploadResult { + val databaseAttachment = SignalDatabase.attachments.getAttachment(attachmentId)!! + + return AttachmentUploadResult( + remoteId = SignalServiceAttachmentRemoteId.V4("somewhere-${Random.nextLong()}"), + cdnNumber = Cdn.CDN_3.cdnNumber, + key = databaseAttachment.remoteKey?.let { Base64.decode(it) } ?: Util.getSecretBytes(64), + iv = databaseAttachment.remoteIv ?: Util.getSecretBytes(16), + digest = Random.nextBytes(32), + incrementalDigest = Random.nextBytes(16), + incrementalDigestChunkSize = 5, + uploadTimestamp = uploadTimestamp, + dataSize = databaseAttachment.size, + blurHash = databaseAttachment.blurHash?.hash + ) + } +} diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest_deduping.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest_deduping.kt index 0c111cf5ee..2a59eeeae9 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest_deduping.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest_deduping.kt @@ -28,8 +28,6 @@ import org.thoughtcrime.securesms.providers.BlobProvider import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.util.MediaUtil import org.thoughtcrime.securesms.util.Util -import org.whispersystems.signalservice.api.attachment.AttachmentUploadResult -import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId import org.whispersystems.signalservice.api.push.ServiceId import org.whispersystems.signalservice.internal.crypto.PaddingInputStream import java.io.File @@ -729,7 +727,7 @@ class AttachmentTableTest_deduping { fun upload(attachmentId: AttachmentId, uploadTimestamp: Long = System.currentTimeMillis()) { SignalDatabase.attachments.createKeyIvIfNecessary(attachmentId) - SignalDatabase.attachments.finalizeAttachmentAfterUpload(attachmentId, createUploadResult(attachmentId, uploadTimestamp)) + SignalDatabase.attachments.finalizeAttachmentAfterUpload(attachmentId, AttachmentTableTestUtil.createUploadResult(attachmentId, uploadTimestamp)) val attachment = SignalDatabase.attachments.getAttachment(attachmentId)!! SignalDatabase.attachments.setArchiveCdn( @@ -875,23 +873,6 @@ class AttachmentTableTest_deduping { private fun ByteArray.asMediaStream(): MediaStream { return MediaStream(this.inputStream(), MediaUtil.IMAGE_JPEG, 2, 2) } - - private fun createUploadResult(attachmentId: AttachmentId, uploadTimestamp: Long = System.currentTimeMillis()): AttachmentUploadResult { - val databaseAttachment = SignalDatabase.attachments.getAttachment(attachmentId)!! - - return AttachmentUploadResult( - remoteId = SignalServiceAttachmentRemoteId.V4("somewhere-${Random.nextLong()}"), - cdnNumber = Cdn.CDN_3.cdnNumber, - key = databaseAttachment.remoteKey?.let { Base64.decode(it) } ?: Util.getSecretBytes(64), - iv = databaseAttachment.remoteIv ?: Util.getSecretBytes(16), - digest = Random.nextBytes(32), - incrementalDigest = Random.nextBytes(16), - incrementalDigestChunkSize = 5, - uploadTimestamp = uploadTimestamp, - dataSize = databaseAttachment.size, - blurHash = databaseAttachment.blurHash?.hash - ) - } } private fun test(content: TestContext.() -> Unit) { diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/AttachmentTable.kt b/app/src/main/java/org/thoughtcrime/securesms/database/AttachmentTable.kt index 234047c7a9..c62300d650 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/AttachmentTable.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/AttachmentTable.kt @@ -663,6 +663,20 @@ class AttachmentTable( } } + /** + * Sets the archive transfer state for the given attachment by digest. + */ + fun resetArchiveTransferStateByDigest(digest: ByteArray) { + writableDatabase + .update(TABLE_NAME) + .values( + ARCHIVE_TRANSFER_STATE to ArchiveTransferState.NONE.value, + ARCHIVE_CDN to 0 + ) + .where("$REMOTE_DIGEST = ?", digest) + .run() + } + /** * Sets the archive transfer state for the given attachment and all other attachments that share the same data file. */ diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/BackupMediaSnapshotTable.kt b/app/src/main/java/org/thoughtcrime/securesms/database/BackupMediaSnapshotTable.kt index 96508e72b8..86313f8e4b 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/BackupMediaSnapshotTable.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/BackupMediaSnapshotTable.kt @@ -6,17 +6,21 @@ package org.thoughtcrime.securesms.database import android.content.Context +import android.database.Cursor import androidx.annotation.VisibleForTesting import androidx.core.content.contentValuesOf import org.signal.core.util.SqlUtil import org.signal.core.util.delete import org.signal.core.util.readToList import org.signal.core.util.readToSet +import org.signal.core.util.requireBoolean import org.signal.core.util.requireInt import org.signal.core.util.requireNonNullBlob import org.signal.core.util.requireNonNullString import org.signal.core.util.select import org.signal.core.util.toInt +import org.signal.core.util.update +import org.signal.core.util.updateAll import org.thoughtcrime.securesms.backup.v2.ArchivedMediaObject /** @@ -60,6 +64,11 @@ class BackupMediaSnapshotTable(context: Context, database: SignalDatabase) : Dat */ const val IS_THUMBNAIL = "is_thumbnail" + /** + * Timestamp when media was last seen on archive cdn. Can be reset to default. + */ + const val LAST_SEEN_ON_REMOTE_TIMESTAMP = "last_seen_on_remote_timestamp" + /** * The remote digest for the media object. This is used to find matching attachments in the attachment table when necessary. */ @@ -73,7 +82,8 @@ class BackupMediaSnapshotTable(context: Context, database: SignalDatabase) : Dat $LAST_SYNC_TIME INTEGER DEFAULT 0, $PENDING_SYNC_TIME INTEGER, $IS_THUMBNAIL INTEGER DEFAULT 0, - $REMOTE_DIGEST BLOB NOT NULL + $REMOTE_DIGEST BLOB NOT NULL, + $LAST_SEEN_ON_REMOTE_TIMESTAMP INTEGER DEFAULT 0 ) """.trimIndent() } @@ -132,24 +142,30 @@ class BackupMediaSnapshotTable(context: Context, database: SignalDatabase) : Dat return emptySet() } - val query = SqlUtil.buildSingleCollectionQuery( + val queries: List = SqlUtil.buildCollectionQuery( column = MEDIA_ID, values = objects.map { it.mediaId }, collectionOperator = SqlUtil.CollectionOperator.NOT_IN, prefix = "$IS_THUMBNAIL = 0 AND " ) - return readableDatabase - .select(MEDIA_ID, CDN) - .from(TABLE_NAME) - .where(query.where, query.whereArgs) - .run() - .readToSet { - ArchivedMediaObject( - mediaId = it.requireNonNullString(MEDIA_ID), - cdn = it.requireInt(CDN) - ) - } + val out: MutableSet = mutableSetOf() + + for (query in queries) { + out += readableDatabase + .select(MEDIA_ID, CDN) + .from(TABLE_NAME) + .where(query.where, query.whereArgs) + .run() + .readToSet { + ArchivedMediaObject( + mediaId = it.requireNonNullString(MEDIA_ID), + cdn = it.requireInt(CDN) + ) + } + } + + return out } /** @@ -177,6 +193,47 @@ class BackupMediaSnapshotTable(context: Context, database: SignalDatabase) : Dat } } + /** + * Indicate the time that the set of media objects were seen on the archive CDN. Can be used to reconcile our local state with the server state. + */ + fun markSeenOnRemote(mediaIdBatch: Collection, time: Long) { + if (mediaIdBatch.isEmpty()) { + return + } + + val query = SqlUtil.buildFastCollectionQuery(MEDIA_ID, mediaIdBatch) + writableDatabase + .update(TABLE_NAME) + .values(LAST_SEEN_ON_REMOTE_TIMESTAMP to time) + .where(query.where, query.whereArgs) + .run() + } + + /** + * Get all media objects who were last seen on the remote server before the given time. + * This is used to find media objects that have not been seen on the CDN, even though they should be. + * + * The cursor contains rows that can be parsed into [MediaEntry] objects. + */ + fun getMediaObjectsLastSeenOnCdnBeforeTime(time: Long): Cursor { + return readableDatabase + .select(MEDIA_ID, CDN, REMOTE_DIGEST, IS_THUMBNAIL) + .from(TABLE_NAME) + .where("$LAST_SEEN_ON_REMOTE_TIMESTAMP < $time") + .run() + } + + /** + * Resets the [LAST_SEEN_ON_REMOTE_TIMESTAMP] column back to zero. It's a good idea to do this after you have run a sync and used the value, as it can + * mitigate various issues that can arise from having an incorrect local clock. + */ + fun clearLastSeenOnRemote() { + writableDatabase + .updateAll(TABLE_NAME) + .values(LAST_SEEN_ON_REMOTE_TIMESTAMP to 0) + .run() + } + private fun writePendingMediaObjectsChunk(chunk: List, pendingSyncTime: Long) { val values = chunk.map { contentValuesOf( @@ -213,10 +270,21 @@ class BackupMediaSnapshotTable(context: Context, database: SignalDatabase) : Dat val cdn: Int ) - private data class MediaEntry( + class MediaEntry( val mediaId: String, val cdn: Int, val digest: ByteArray, val isThumbnail: Boolean - ) + ) { + companion object { + fun fromCursor(cursor: Cursor): MediaEntry { + return MediaEntry( + mediaId = cursor.requireNonNullString(MEDIA_ID), + cdn = cursor.requireInt(CDN), + digest = cursor.requireNonNullBlob(REMOTE_DIGEST), + isThumbnail = cursor.requireBoolean(IS_THUMBNAIL) + ) + } + } + } } diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/helpers/SignalDatabaseMigrations.kt b/app/src/main/java/org/thoughtcrime/securesms/database/helpers/SignalDatabaseMigrations.kt index c45108badc..ba68ea6d3e 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/helpers/SignalDatabaseMigrations.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/helpers/SignalDatabaseMigrations.kt @@ -128,6 +128,7 @@ import org.thoughtcrime.securesms.database.helpers.migration.V270_FixChatFolderC import org.thoughtcrime.securesms.database.helpers.migration.V271_AddNotificationProfileIdColumn import org.thoughtcrime.securesms.database.helpers.migration.V272_UpdateUnreadCountIndices import org.thoughtcrime.securesms.database.helpers.migration.V273_FixUnreadOriginalMessages +import org.thoughtcrime.securesms.database.helpers.migration.V274_BackupMediaSnapshotLastSeenOnRemote import org.thoughtcrime.securesms.database.SQLiteDatabase as SignalSqliteDatabase /** @@ -261,10 +262,11 @@ object SignalDatabaseMigrations { 270 to V270_FixChatFolderColumnsForStorageSync, 271 to V271_AddNotificationProfileIdColumn, 272 to V272_UpdateUnreadCountIndices, - 273 to V273_FixUnreadOriginalMessages + 273 to V273_FixUnreadOriginalMessages, + 274 to V274_BackupMediaSnapshotLastSeenOnRemote ) - const val DATABASE_VERSION = 273 + const val DATABASE_VERSION = 274 @JvmStatic fun migrate(context: Application, db: SignalSqliteDatabase, oldVersion: Int, newVersion: Int) { diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/helpers/migration/V274_BackupMediaSnapshotLastSeenOnRemote.kt b/app/src/main/java/org/thoughtcrime/securesms/database/helpers/migration/V274_BackupMediaSnapshotLastSeenOnRemote.kt new file mode 100644 index 0000000000..12da48e38b --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/database/helpers/migration/V274_BackupMediaSnapshotLastSeenOnRemote.kt @@ -0,0 +1,18 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.thoughtcrime.securesms.database.helpers.migration + +import android.app.Application +import org.thoughtcrime.securesms.database.SQLiteDatabase + +/** + * Added a column to the backup media snapshot table to keep track of the last time we saw an object on the CDN. + */ +object V274_BackupMediaSnapshotLastSeenOnRemote : SignalDatabaseMigration { + override fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) { + db.execSQL("ALTER TABLE backup_media_snapshot ADD COLUMN last_seen_on_remote_timestamp INTEGER DEFAULT 0") + } +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/ArchiveAttachmentBackfillJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/ArchiveAttachmentBackfillJob.kt index 9e4883611f..be370815c9 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/ArchiveAttachmentBackfillJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/ArchiveAttachmentBackfillJob.kt @@ -14,8 +14,7 @@ import org.thoughtcrime.securesms.keyvalue.SignalStore import kotlin.time.Duration.Companion.days /** - * When run, this will find the next attachment that needs to be uploaded to the archive service and upload it. - * It will enqueue a copy of itself if it thinks there is more work to be done, and that copy will continue the upload process. + * When run, this will find all of the attachments that need to be uploaded to the archive tier and enqueue [UploadAttachmentToArchiveJob]s for them. */ class ArchiveAttachmentBackfillJob private constructor(parameters: Parameters) : Job(parameters) { companion object { diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/BackupMediaSnapshotSyncJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/BackupMediaSnapshotSyncJob.kt index 98a3684bce..6461fb1674 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/BackupMediaSnapshotSyncJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/BackupMediaSnapshotSyncJob.kt @@ -5,10 +5,12 @@ package org.thoughtcrime.securesms.jobs +import org.signal.core.util.forEach import org.signal.core.util.logging.Log import org.signal.core.util.nullIfBlank import org.thoughtcrime.securesms.backup.v2.ArchivedMediaObject import org.thoughtcrime.securesms.backup.v2.BackupRepository +import org.thoughtcrime.securesms.database.BackupMediaSnapshotTable import org.thoughtcrime.securesms.database.SignalDatabase import org.thoughtcrime.securesms.dependencies.AppDependencies import org.thoughtcrime.securesms.jobmanager.Job @@ -40,7 +42,8 @@ class BackupMediaSnapshotSyncJob private constructor( const val KEY = "BackupMediaSnapshotSyncJob" - private const val REMOTE_DELETE_BATCH_SIZE = 500 + private const val REMOTE_DELETE_BATCH_SIZE = 750 + private const val CDN_PAGE_SIZE = 10_000 private val BACKUP_MEDIA_SYNC_INTERVAL = 7.days.inWholeMilliseconds fun enqueue(syncTime: Long) { @@ -79,7 +82,9 @@ class BackupMediaSnapshotSyncJob private constructor( return syncDataFromCdn() ?: Result.success() } - override fun onFailure() = Unit + override fun onFailure() { + SignalDatabase.backupMediaSnapshots.clearLastSeenOnRemote() + } /** * Looks through our local snapshot of what attachments we put in the last backup file, and uses that to delete any old attachments from the archive CDN @@ -92,7 +97,7 @@ class BackupMediaSnapshotSyncJob private constructor( deleteMediaObjectsFromCdn(mediaObjects)?.let { result -> return result } SignalDatabase.backupMediaSnapshots.deleteMediaObjects(mediaObjects) - mediaObjects = SignalDatabase.backupMediaSnapshots.getPageOfOldMediaObjects(syncTime, REMOTE_DELETE_BATCH_SIZE) + mediaObjects = SignalDatabase.backupMediaSnapshots.getPageOfOldMediaObjects(syncTime, CDN_PAGE_SIZE) } return null @@ -135,6 +140,25 @@ class BackupMediaSnapshotSyncJob private constructor( deleteMediaObjectsFromCdn(attachmentsToDelete)?.let { result -> return result } } + val entriesNeedingRepairCursor = SignalDatabase.backupMediaSnapshots.getMediaObjectsLastSeenOnCdnBeforeTime(syncTime) + val needRepairCount = entriesNeedingRepairCursor.count + + if (needRepairCount > 0) { + Log.w(TAG, "Found $needRepairCount attachments that we thought were uploaded, but could not be found on the CDN. Clearing state and enqueuing uploads.") + + entriesNeedingRepairCursor.forEach { + val entry = BackupMediaSnapshotTable.MediaEntry.fromCursor(it) + // TODO [backup] Re-enqueue thumbnail uploads if necessary + if (!entry.isThumbnail) { + SignalDatabase.attachments.resetArchiveTransferStateByDigest(entry.digest) + } + } + + BackupMessagesJob.enqueue() + } else { + Log.d(TAG, "No attachments need to be repaired.") + } + SignalStore.backup.lastMediaSyncTime = System.currentTimeMillis() return null @@ -152,6 +176,11 @@ class BackupMediaSnapshotSyncJob private constructor( ) } + SignalDatabase.backupMediaSnapshots.markSeenOnRemote( + mediaIdBatch = mediaObjects.map { it.mediaId }, + time = syncTime + ) + val notFoundMediaObjects = SignalDatabase.backupMediaSnapshots.getMediaObjectsThatCantBeFound(mediaObjects) val remainingObjects = mediaObjects - notFoundMediaObjects diff --git a/core-util/build.gradle.kts b/core-util/build.gradle.kts index 4c0d731978..a4a18f5ad4 100644 --- a/core-util/build.gradle.kts +++ b/core-util/build.gradle.kts @@ -12,8 +12,10 @@ dependencies { implementation(libs.androidx.sqlite) implementation(libs.androidx.documentfile) + testImplementation(libs.androidx.sqlite.framework) testImplementation(testLibs.junit.junit) + testImplementation(testLibs.assertk) testImplementation(testLibs.robolectric.robolectric) } diff --git a/core-util/src/main/java/org/signal/core/util/SQLiteDatabaseExtensions.kt b/core-util/src/main/java/org/signal/core/util/SQLiteDatabaseExtensions.kt index 37428ee90b..9b76e5899e 100644 --- a/core-util/src/main/java/org/signal/core/util/SQLiteDatabaseExtensions.kt +++ b/core-util/src/main/java/org/signal/core/util/SQLiteDatabaseExtensions.kt @@ -6,6 +6,7 @@ import android.database.sqlite.SQLiteDatabase import androidx.core.content.contentValuesOf import androidx.sqlite.db.SupportSQLiteDatabase import androidx.sqlite.db.SupportSQLiteQueryBuilder +import androidx.sqlite.db.SupportSQLiteStatement import org.signal.core.util.SqlUtil.ForeignKeyViolation import org.signal.core.util.logging.Log import kotlin.time.Duration @@ -246,10 +247,34 @@ fun SupportSQLiteDatabase.deleteAll(tableName: String): Int { return this.delete(tableName, null, arrayOfNulls(0)) } +/** + * Begins an INSERT statement with a helpful builder pattern. + */ fun SupportSQLiteDatabase.insertInto(tableName: String): InsertBuilderPart1 { return InsertBuilderPart1(this, tableName) } +/** + * Bind an arbitrary value to an index. It will handle calling the correct bind method based on the class type. + * @param index The index you want to bind to. Important: Indexes start at 1, not 0. + */ +fun SupportSQLiteStatement.bindValue(index: Int, value: Any?) { + when (value) { + null -> this.bindNull(index) + is DatabaseId -> this.bindString(index, value.serialize()) + is Boolean -> this.bindLong(index, value.toInt().toLong()) + is ByteArray -> this.bindBlob(index, value) + is Number -> { + if (value.toLong() == value) { + this.bindLong(index, value.toLong()) + } else { + this.bindDouble(index, value.toDouble()) + } + } + else -> this.bindString(index, value.toString()) + } +} + class SelectBuilderPart1( private val db: SupportSQLiteDatabase, private val columns: Array @@ -422,7 +447,7 @@ class UpdateBuilderPart2( ) { fun where(where: String, vararg whereArgs: Any): UpdateBuilderPart3 { require(where.isNotBlank()) - return UpdateBuilderPart3(db, tableName, values, where, SqlUtil.buildArgs(*whereArgs)) + return UpdateBuilderPart3(db, tableName, values, where, whereArgs.toArgs()) } fun where(where: String, whereArgs: Array): UpdateBuilderPart3 { @@ -436,11 +461,45 @@ class UpdateBuilderPart3( private val tableName: String, private val values: ContentValues, private val where: String, - private val whereArgs: Array + private val whereArgs: Array ) { @JvmOverloads fun run(conflictStrategy: Int = SQLiteDatabase.CONFLICT_NONE): Int { - return db.update(tableName, conflictStrategy, values, where, whereArgs) + val query = StringBuilder("UPDATE $tableName SET ") + + val contentValuesKeys = values.keySet() + for ((index, column) in contentValuesKeys.withIndex()) { + query.append(column).append(" = ?") + if (index < contentValuesKeys.size - 1) { + query.append(", ") + } + } + + query.append(" WHERE ").append(where) + + val conflictString = when (conflictStrategy) { + SQLiteDatabase.CONFLICT_IGNORE -> " ON CONFLICT IGNORE" + SQLiteDatabase.CONFLICT_ABORT -> " ON CONFLICT ABORT" + SQLiteDatabase.CONFLICT_FAIL -> " ON CONFLICT FAIL" + SQLiteDatabase.CONFLICT_ROLLBACK -> " ON CONFLICT ROLLBACK" + SQLiteDatabase.CONFLICT_REPLACE -> " ON CONFLICT REPLACE" + else -> "" + } + query.append(conflictString) + + val statement = db.compileStatement(query.toString()) + var bindIndex = 1 + for (key in contentValuesKeys) { + statement.bindValue(bindIndex, values.get(key)) + bindIndex++ + } + + for (arg in whereArgs) { + statement.bindValue(bindIndex, arg) + bindIndex++ + } + + return statement.use { it.executeUpdateDelete() } } } @@ -550,6 +609,20 @@ class InsertBuilderPart2( } } +/** + * Helper function to massage passed-in arguments into a better form to give to the database. + */ +private fun Array.toArgs(): Array { + return this + .map { + when (it) { + is DatabaseId -> it.serialize() + else -> it + } + } + .toTypedArray() +} + data class ForeignKeyConstraint( val table: String, val column: String, diff --git a/core-util/src/test/java/org/signal/core/util/InMemorySqliteOpenHelper.kt b/core-util/src/test/java/org/signal/core/util/InMemorySqliteOpenHelper.kt new file mode 100644 index 0000000000..395b844ff7 --- /dev/null +++ b/core-util/src/test/java/org/signal/core/util/InMemorySqliteOpenHelper.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.core.util + +import androidx.sqlite.db.SupportSQLiteDatabase +import androidx.sqlite.db.SupportSQLiteOpenHelper +import androidx.sqlite.db.framework.FrameworkSQLiteOpenHelperFactory +import androidx.test.core.app.ApplicationProvider + +/** + * Helper to create an in-memory database used for testing SQLite stuff. + */ +object InMemorySqliteOpenHelper { + fun create( + onCreate: (db: SupportSQLiteDatabase) -> Unit, + onUpgrade: (db: SupportSQLiteDatabase, oldVersion: Int, newVersion: Int) -> Unit = { _, _, _ -> } + ): SupportSQLiteOpenHelper { + val configuration = SupportSQLiteOpenHelper.Configuration( + context = ApplicationProvider.getApplicationContext(), + name = "test", + callback = object : SupportSQLiteOpenHelper.Callback(1) { + override fun onCreate(db: SupportSQLiteDatabase) = onCreate(db) + override fun onUpgrade(db: SupportSQLiteDatabase, oldVersion: Int, newVersion: Int) = onUpgrade(db, oldVersion, newVersion) + }, + useNoBackupDirectory = false, + allowDataLossOnRecovery = true + ) + + return FrameworkSQLiteOpenHelperFactory().create(configuration) + } +} diff --git a/core-util/src/test/java/org/signal/core/util/SQLiteDatabaseExtensionsTest.kt b/core-util/src/test/java/org/signal/core/util/SQLiteDatabaseExtensionsTest.kt new file mode 100644 index 0000000000..0dfc28fbd6 --- /dev/null +++ b/core-util/src/test/java/org/signal/core/util/SQLiteDatabaseExtensionsTest.kt @@ -0,0 +1,118 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.core.util + +import android.app.Application +import androidx.sqlite.db.SupportSQLiteOpenHelper +import assertk.assertThat +import assertk.assertions.isEqualTo +import assertk.assertions.isNotNull +import org.junit.Assert.assertArrayEquals +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +import org.robolectric.annotation.Config + +@RunWith(RobolectricTestRunner::class) +@Config(manifest = Config.NONE, application = Application::class) +class SQLiteDatabaseExtensionsTest { + + lateinit var db: SupportSQLiteOpenHelper + + companion object { + const val TABLE_NAME = "test" + const val ID = "_id" + const val STRING_COLUMN = "string_column" + const val LONG_COLUMN = "long_column" + const val DOUBLE_COLUMN = "double_column" + const val BLOB_COLUMN = "blob_column" + } + + @Before + fun setup() { + db = InMemorySqliteOpenHelper.create( + onCreate = { db -> + db.execSQL("CREATE TABLE $TABLE_NAME ($ID INTEGER PRIMARY KEY AUTOINCREMENT, $STRING_COLUMN TEXT, $LONG_COLUMN INTEGER, $DOUBLE_COLUMN DOUBLE, $BLOB_COLUMN BLOB)") + } + ) + + db.writableDatabase.insertInto(TABLE_NAME) + .values( + STRING_COLUMN to "asdf", + LONG_COLUMN to 1, + DOUBLE_COLUMN to 0.5f, + BLOB_COLUMN to byteArrayOf(1, 2, 3) + ) + .run() + } + + @Test + fun `update - content values work`() { + val updateCount: Int = db.writableDatabase + .update("test") + .values( + STRING_COLUMN to "asdf2", + LONG_COLUMN to 2, + DOUBLE_COLUMN to 1.5f, + BLOB_COLUMN to byteArrayOf(4, 5, 6) + ) + .where("$ID = ?", 1) + .run() + + val record = readRecord(1) + + assertThat(updateCount).isEqualTo(1) + assertThat(record).isNotNull() + assertThat(record!!.id).isEqualTo(1) + assertThat(record.stringColumn).isEqualTo("asdf2") + assertThat(record.longColumn).isEqualTo(2) + assertThat(record.doubleColumn).isEqualTo(1.5f) + assertArrayEquals(record.blobColumn, byteArrayOf(4, 5, 6)) + } + + @Test + fun `update - querying by blob works`() { + val updateCount: Int = db.writableDatabase + .update("test") + .values( + STRING_COLUMN to "asdf2" + ) + .where("$BLOB_COLUMN = ?", byteArrayOf(1, 2, 3)) + .run() + + val record = readRecord(1) + + assertThat(updateCount).isEqualTo(1) + assertThat(record).isNotNull() + assertThat(record!!.stringColumn).isEqualTo("asdf2") + } + + private fun readRecord(id: Long): TestRecord? { + return db.readableDatabase + .select() + .from(TABLE_NAME) + .where("$ID = ?", id) + .run() + .readToSingleObject { + TestRecord( + id = it.requireLong(ID), + stringColumn = it.requireString(STRING_COLUMN), + longColumn = it.requireLong(LONG_COLUMN), + doubleColumn = it.requireFloat(DOUBLE_COLUMN), + blobColumn = it.requireBlob(BLOB_COLUMN) + ) + } + } + + class TestRecord( + val id: Long, + val stringColumn: String?, + val longColumn: Long, + val doubleColumn: Float, + val blobColumn: ByteArray? + ) +}