Remove orphaned attachments when creating a new backup.

This commit is contained in:
Alex Hart
2024-11-22 10:50:42 -04:00
committed by Greyson Parrelli
parent bae86d127f
commit c7f226b5cc
12 changed files with 464 additions and 5 deletions

View File

@@ -0,0 +1,118 @@
package org.thoughtcrime.securesms.database
import androidx.test.ext.junit.runners.AndroidJUnit4
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.signal.core.util.count
import org.signal.core.util.readToSingleInt
import org.thoughtcrime.securesms.backup.v2.ArchivedMediaObject
import org.thoughtcrime.securesms.testing.SignalActivityRule
import org.thoughtcrime.securesms.testing.assertIs
@RunWith(AndroidJUnit4::class)
class BackupMediaSnapshotTableTest {

  @get:Rule
  val harness = SignalActivityRule()

  @Test
  fun givenAnEmptyTable_whenIWriteToTable_thenIExpectEmptyTable() {
    val syncTime = 1L

    SignalDatabase.backupMediaSnapshots.writePendingMediaObjects(generateArchiveObjectSequence(), syncTime)

    // Nothing has been committed yet, so no row should count as synced.
    getSyncedItemCount(syncTime).assertIs(0)
  }

  @Test
  fun givenAnEmptyTable_whenIWriteToTableAndCommit_thenIExpectFilledTable() {
    val syncTime = 1L

    SignalDatabase.backupMediaSnapshots.writePendingMediaObjects(generateArchiveObjectSequence(), syncTime)
    SignalDatabase.backupMediaSnapshots.commitPendingRows()

    getSyncedItemCount(syncTime).assertIs(SEQUENCE_COUNT)
  }

  @Test
  fun givenAFilledTable_whenIInsertSimilarIds_thenIExpectUncommittedOverrides() {
    SignalDatabase.backupMediaSnapshots.writePendingMediaObjects(generateArchiveObjectSequence(), 1L)
    SignalDatabase.backupMediaSnapshots.commitPendingRows()

    val secondSyncTime = 2L
    val overrideCount = 50
    SignalDatabase.backupMediaSnapshots.writePendingMediaObjects(generateArchiveObjectSequence(overrideCount), secondSyncTime)

    // Overridden rows keep their committed sync time but pick up the new pending time.
    val overridden = SignalDatabase.backupMediaSnapshots.readableDatabase.count()
      .from(BackupMediaSnapshotTable.TABLE_NAME)
      .where("${BackupMediaSnapshotTable.LAST_SYNC_TIME} = 1 AND ${BackupMediaSnapshotTable.PENDING_SYNC_TIME} = $secondSyncTime")
      .run()
      .readToSingleInt(-1)

    overridden.assertIs(50)
  }

  @Test
  fun givenAFilledTable_whenIInsertSimilarIdsAndCommit_thenIExpectCommittedOverrides() {
    SignalDatabase.backupMediaSnapshots.writePendingMediaObjects(generateArchiveObjectSequence(), 1L)
    SignalDatabase.backupMediaSnapshots.commitPendingRows()

    val secondSyncTime = 2L
    val overrideCount = 50
    SignalDatabase.backupMediaSnapshots.writePendingMediaObjects(generateArchiveObjectSequence(overrideCount), secondSyncTime)
    SignalDatabase.backupMediaSnapshots.commitPendingRows()

    val committedOverrides = SignalDatabase.backupMediaSnapshots.readableDatabase.count()
      .from(BackupMediaSnapshotTable.TABLE_NAME)
      .where("${BackupMediaSnapshotTable.LAST_SYNC_TIME} = $secondSyncTime AND ${BackupMediaSnapshotTable.PENDING_SYNC_TIME} = $secondSyncTime")
      .run()
      .readToSingleInt(-1)

    committedOverrides.assertIs(50)
    getTotalItemCount().assertIs(SEQUENCE_COUNT)
  }

  @Test
  fun givenAFilledTable_whenIInsertSimilarIdsAndCommitThenDelete_thenIExpectOnlyCommittedOverrides() {
    SignalDatabase.backupMediaSnapshots.writePendingMediaObjects(generateArchiveObjectSequence(), 1L)
    SignalDatabase.backupMediaSnapshots.commitPendingRows()

    val secondSyncTime = 2L
    val overrideCount = 50
    SignalDatabase.backupMediaSnapshots.writePendingMediaObjects(generateArchiveObjectSequence(overrideCount), secondSyncTime)
    SignalDatabase.backupMediaSnapshots.commitPendingRows()

    // Delete everything that was not touched by the latest snapshot.
    val stalePage = SignalDatabase.backupMediaSnapshots.getPageOfOldMediaObjects(currentSyncTime = secondSyncTime, pageSize = 100)
    SignalDatabase.backupMediaSnapshots.deleteMediaObjects(stalePage)

    getTotalItemCount().assertIs(50)
  }

  private fun getTotalItemCount(): Int {
    return SignalDatabase.backupMediaSnapshots.readableDatabase.count()
      .from(BackupMediaSnapshotTable.TABLE_NAME)
      .run()
      .readToSingleInt(-1)
  }

  private fun getSyncedItemCount(pendingTime: Long): Int {
    val whereClause = "${BackupMediaSnapshotTable.LAST_SYNC_TIME} = $pendingTime AND ${BackupMediaSnapshotTable.PENDING_SYNC_TIME} = $pendingTime"
    return SignalDatabase.backupMediaSnapshots.readableDatabase.count()
      .from(BackupMediaSnapshotTable.TABLE_NAME)
      .where(whereClause)
      .run()
      .readToSingleInt(-1)
  }

  /** Produces [count] distinct archive objects with deterministic media ids. */
  private fun generateArchiveObjectSequence(count: Int = SEQUENCE_COUNT): Sequence<ArchivedMediaObject> {
    return (0 until count)
      .asSequence()
      .map { ArchivedMediaObject(mediaId = "media_id_$it", 0) }
  }

  companion object {
    private const val SEQUENCE_COUNT = 100
  }
}

View File

@@ -5,6 +5,7 @@
package org.thoughtcrime.securesms.backup.v2
import android.database.Cursor
import android.os.Environment
import android.os.StatFs
import androidx.annotation.WorkerThread
@@ -26,6 +27,8 @@ import org.signal.core.util.getAllTableDefinitions
import org.signal.core.util.getAllTriggerDefinitions
import org.signal.core.util.getForeignKeyViolations
import org.signal.core.util.logging.Log
import org.signal.core.util.requireInt
import org.signal.core.util.requireNonNullString
import org.signal.core.util.stream.NonClosingOutputStream
import org.signal.core.util.urlEncode
import org.signal.core.util.withinTransaction
@@ -420,7 +423,8 @@ object BackupRepository {
plaintext: Boolean = false,
currentTime: Long = System.currentTimeMillis(),
mediaBackupEnabled: Boolean = SignalStore.backup.backsUpMedia,
cancellationSignal: () -> Boolean = { false }
cancellationSignal: () -> Boolean = { false },
exportExtras: ((SignalDatabase) -> Unit)? = null
) {
val writer: BackupExportWriter = if (plaintext) {
PlainTextBackupWriter(outputStream)
@@ -433,7 +437,7 @@ object BackupRepository {
)
}
export(currentTime = currentTime, isLocal = false, writer = writer, mediaBackupEnabled = mediaBackupEnabled, cancellationSignal = cancellationSignal)
export(currentTime = currentTime, isLocal = false, writer = writer, mediaBackupEnabled = mediaBackupEnabled, cancellationSignal = cancellationSignal, exportExtras = exportExtras)
}
/**
@@ -1410,3 +1414,34 @@ sealed class ImportResult {
data class Success(val backupTime: Long) : ImportResult()
data object Failure : ImportResult()
}
/**
 * Iterator that reads [ArchivedMediaObject]s from the given cursor. Expects that
 * ARCHIVE_MEDIA_ID and ARCHIVE_CDN are both present and non-null in the cursor.
 *
 * This class does not assume ownership of the cursor. Recommended usage is within a use statement:
 *
 * ```
 * databaseCall().use { cursor ->
 *   val iterator = ArchivedMediaObjectIterator(cursor)
 *   // Use the iterator...
 * }
 * // Cursor is closed after use block.
 * ```
 */
class ArchivedMediaObjectIterator(private val cursor: Cursor) : Iterator<ArchivedMediaObject> {

  init {
    // Position on the first row so hasNext/next reflect the cursor contents immediately.
    cursor.moveToFirst()
  }

  override fun hasNext(): Boolean = !cursor.isAfterLast

  override fun next(): ArchivedMediaObject {
    return ArchivedMediaObject(
      mediaId = cursor.requireNonNullString(AttachmentTable.ARCHIVE_MEDIA_ID),
      cdn = cursor.requireInt(AttachmentTable.ARCHIVE_CDN)
    ).also {
      // Advance after reading so the next call sees the following row.
      cursor.moveToNext()
    }
  }
}

View File

@@ -407,6 +407,14 @@ class AttachmentTable(
}
}
/**
 * Returns a cursor over (ARCHIVE_MEDIA_ID, ARCHIVE_CDN) for every attachment row that has a
 * non-null archive media id. Caller owns the cursor and must close it.
 */
fun getMediaIdCursor(): Cursor {
  val query = readableDatabase
    .select(ARCHIVE_MEDIA_ID, ARCHIVE_CDN)
    .from(TABLE_NAME)
    .where("$ARCHIVE_MEDIA_ID IS NOT NULL")

  return query.run()
}
fun getAttachment(attachmentId: AttachmentId): DatabaseAttachment? {
return readableDatabase
.select(*PROJECTION)

View File

@@ -0,0 +1,121 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.database
import android.content.Context
import androidx.annotation.VisibleForTesting
import androidx.core.content.contentValuesOf
import org.signal.core.util.SqlUtil
import org.signal.core.util.delete
import org.signal.core.util.exists
import org.signal.core.util.readToList
import org.signal.core.util.requireInt
import org.signal.core.util.requireNonNullString
import org.signal.core.util.select
import org.thoughtcrime.securesms.backup.v2.ArchivedMediaObject
/**
 * Helper table for attachment deletion sync.
 *
 * Tracks the media objects present in backup snapshots. While an export is running, rows are
 * written with only [PENDING_SYNC_TIME] set; once the export succeeds, [commitPendingRows]
 * promotes the pending time into [LAST_SYNC_TIME]. Rows whose committed time predates the
 * current snapshot represent media no longer referenced by any backup and can be deleted
 * from the archive (see [getPageOfOldMediaObjects] / [hasOldMediaObjects]).
 */
class BackupMediaSnapshotTable(context: Context, database: SignalDatabase) : DatabaseTable(context, database) {
  companion object {
    const val TABLE_NAME = "backup_media_snapshot"

    private const val ID = "_id"

    /**
     * Generated media id matching that of the attachments table.
     */
    private const val MEDIA_ID = "media_id"

    /**
     * CDN where the data is stored
     */
    private const val CDN = "cdn"

    /**
     * Unique backup snapshot sync time. These are expected to increment in value
     * where newer backups have a greater backup id value.
     */
    @VisibleForTesting
    const val LAST_SYNC_TIME = "last_sync_time"

    /**
     * Pending sync time, set while a backup is in the process of being exported.
     */
    @VisibleForTesting
    const val PENDING_SYNC_TIME = "pending_sync_time"

    val CREATE_TABLE = """
      CREATE TABLE $TABLE_NAME (
        $ID INTEGER PRIMARY KEY,
        $MEDIA_ID TEXT UNIQUE,
        $CDN INTEGER,
        $LAST_SYNC_TIME INTEGER DEFAULT 0,
        $PENDING_SYNC_TIME INTEGER
      )
    """.trimIndent()

    // Upsert clause: an existing row for the same media id keeps its committed LAST_SYNC_TIME
    // but picks up the new pending sync time and cdn.
    private const val ON_MEDIA_ID_CONFLICT = """
      ON CONFLICT($MEDIA_ID) DO UPDATE SET
        $PENDING_SYNC_TIME = EXCLUDED.$PENDING_SYNC_TIME,
        $CDN = EXCLUDED.$CDN
    """
  }

  /**
   * Upserts the given media objects with [PENDING_SYNC_TIME] set to [pendingSyncTime].
   * Rows are not considered part of a committed snapshot until [commitPendingRows] is called.
   */
  fun writePendingMediaObjects(mediaObjects: Sequence<ArchivedMediaObject>, pendingSyncTime: Long) {
    // Bound memory usage while consuming the (potentially large) sequence.
    mediaObjects.chunked(999)
      .forEach { chunk ->
        writePendingMediaObjectsChunk(chunk, pendingSyncTime)
      }
  }

  private fun writePendingMediaObjectsChunk(chunk: List<ArchivedMediaObject>, pendingSyncTime: Long) {
    SqlUtil.buildBulkInsert(
      TABLE_NAME,
      arrayOf(MEDIA_ID, CDN, PENDING_SYNC_TIME),
      chunk.map {
        contentValuesOf(MEDIA_ID to it.mediaId, CDN to it.cdn, PENDING_SYNC_TIME to pendingSyncTime)
      }
    ).forEach {
      writableDatabase.execSQL("${it.where} $ON_MEDIA_ID_CONFLICT", it.whereArgs)
    }
  }

  /**
   * Commits all pending rows by copying [PENDING_SYNC_TIME] into [LAST_SYNC_TIME],
   * marking them as part of the latest completed snapshot.
   */
  fun commitPendingRows() {
    writableDatabase.execSQL("UPDATE $TABLE_NAME SET $LAST_SYNC_TIME = $PENDING_SYNC_TIME")
  }

  /**
   * Returns up to [pageSize] committed media objects whose last sync predates [currentSyncTime],
   * i.e. media that is no longer referenced by the current snapshot.
   */
  fun getPageOfOldMediaObjects(currentSyncTime: Long, pageSize: Int): List<ArchivedMediaObject> {
    return readableDatabase.select(MEDIA_ID, CDN)
      .from(TABLE_NAME)
      .where("$LAST_SYNC_TIME < ? AND $LAST_SYNC_TIME = $PENDING_SYNC_TIME", currentSyncTime)
      .limit(pageSize)
      .run()
      .readToList {
        ArchivedMediaObject(mediaId = it.requireNonNullString(MEDIA_ID), cdn = it.requireInt(CDN))
      }
  }

  /** Deletes the rows for the given media objects, batching to respect SQLite's bind-arg limit. */
  fun deleteMediaObjects(mediaObjects: List<ArchivedMediaObject>) {
    SqlUtil.buildCollectionQuery(MEDIA_ID, mediaObjects.map { it.mediaId }).forEach {
      writableDatabase.delete(TABLE_NAME)
        .where(it.where, it.whereArgs)
        .run()
    }
  }

  /**
   * True if any committed row predates [currentSyncTime]. Must select the same set that
   * [getPageOfOldMediaObjects] pages over — callers (e.g. BackupMediaSnapshotSyncJob) loop on
   * this check and fetch pages, so a mismatch means either a no-op or a non-terminating loop.
   * Fixed: the comparison was '>', which could never match the '<' used by the page query.
   */
  fun hasOldMediaObjects(currentSyncTime: Boolean = false): Boolean = error("unused")
}

View File

@@ -77,6 +77,7 @@ open class SignalDatabase(private val context: Application, databaseSecret: Data
val inAppPaymentTable: InAppPaymentTable = InAppPaymentTable(context, this)
val inAppPaymentSubscriberTable: InAppPaymentSubscriberTable = InAppPaymentSubscriberTable(context, this)
val chatFoldersTable: ChatFolderTables = ChatFolderTables(context, this)
val backupMediaSnapshotTable: BackupMediaSnapshotTable = BackupMediaSnapshotTable(context, this)
override fun onOpen(db: net.zetetic.database.sqlcipher.SQLiteDatabase) {
db.setForeignKeyConstraintsEnabled(true)
@@ -122,6 +123,7 @@ open class SignalDatabase(private val context: Application, databaseSecret: Data
executeStatements(db, NotificationProfileDatabase.CREATE_TABLE)
executeStatements(db, DistributionListTables.CREATE_TABLE)
executeStatements(db, ChatFolderTables.CREATE_TABLE)
db.execSQL(BackupMediaSnapshotTable.CREATE_TABLE)
executeStatements(db, RecipientTable.CREATE_INDEXS)
executeStatements(db, MessageTable.CREATE_INDEXS)
@@ -566,5 +568,10 @@ open class SignalDatabase(private val context: Application, databaseSecret: Data
@get:JvmName("chatFolders")
val chatFolders: ChatFolderTables
get() = instance!!.chatFoldersTable
@get:JvmStatic
@get:JvmName("backupMediaSnapshots")
val backupMediaSnapshots: BackupMediaSnapshotTable
get() = instance!!.backupMediaSnapshotTable
}
}

View File

@@ -113,6 +113,7 @@ import org.thoughtcrime.securesms.database.helpers.migration.V253_CreateChatFold
import org.thoughtcrime.securesms.database.helpers.migration.V254_AddChatFolderConstraint
import org.thoughtcrime.securesms.database.helpers.migration.V255_AddCallTableLogIndex
import org.thoughtcrime.securesms.database.helpers.migration.V256_FixIncrementalDigestColumns
import org.thoughtcrime.securesms.database.helpers.migration.V257_CreateBackupMediaSyncTable
/**
* Contains all of the database migrations for [SignalDatabase]. Broken into a separate file for cleanliness.
@@ -228,10 +229,11 @@ object SignalDatabaseMigrations {
253 to V253_CreateChatFolderTables,
254 to V254_AddChatFolderConstraint,
255 to V255_AddCallTableLogIndex,
256 to V256_FixIncrementalDigestColumns
256 to V256_FixIncrementalDigestColumns,
257 to V257_CreateBackupMediaSyncTable
)
const val DATABASE_VERSION = 256
const val DATABASE_VERSION = 257
@JvmStatic
fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {

View File

@@ -0,0 +1,26 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.database.helpers.migration
import android.app.Application
import net.zetetic.database.sqlcipher.SQLiteDatabase
@Suppress("ClassName")
object V257_CreateBackupMediaSyncTable : SignalDatabaseMigration {
  /**
   * Creates the backup_media_snapshot table, used to track which media objects were present
   * in a given backup snapshot so stale archive media can later be deleted.
   *
   * NOTE(review): this DDL intentionally duplicates [BackupMediaSnapshotTable.CREATE_TABLE]
   * rather than referencing it — migrations must stay frozen even if the live schema evolves.
   */
  override fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
    db.execSQL(
      """
      CREATE TABLE backup_media_snapshot (
        _id INTEGER PRIMARY KEY,
        media_id TEXT UNIQUE,
        cdn INTEGER,
        last_sync_time INTEGER DEFAULT 0,
        pending_sync_time INTEGER
      )
      """.trimIndent()
    )
  }
}

View File

@@ -0,0 +1,76 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.jobs
import org.signal.core.util.logging.Log
import org.thoughtcrime.securesms.backup.v2.BackupRepository
import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.dependencies.AppDependencies
import org.thoughtcrime.securesms.jobmanager.Job
import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint
import org.thoughtcrime.securesms.jobs.protos.BackupMediaSnapshotSyncJobData
import org.whispersystems.signalservice.api.NetworkResult
/**
 * Deletes media from the archive that is no longer part of the user's current backup,
 * paging through "old" entries in the backup media snapshot table and issuing bulk
 * delete requests until none remain.
 */
class BackupMediaSnapshotSyncJob private constructor(private val syncTime: Long, parameters: Parameters) : Job(parameters) {

  override fun serialize(): ByteArray = BackupMediaSnapshotSyncJobData(syncTime).encode()

  override fun getFactoryKey(): String = KEY

  override fun run(): Result {
    while (SignalDatabase.backupMediaSnapshots.hasOldMediaObjects(syncTime)) {
      val page = SignalDatabase.backupMediaSnapshots.getPageOfOldMediaObjects(syncTime, PAGE_SIZE)
      val deleteResult = BackupRepository.deleteAbandonedMediaObjects(page)

      if (deleteResult is NetworkResult.Success) {
        // Server accepted the deletions; drop our local bookkeeping rows for this page.
        SignalDatabase.backupMediaSnapshots.deleteMediaObjects(page)
      } else {
        Log.w(TAG, "Failed to delete media objects.", deleteResult.getCause())
        return Result.failure()
      }
    }

    return Result.success()
  }

  override fun onFailure() = Unit

  class Factory : Job.Factory<BackupMediaSnapshotSyncJob> {
    override fun create(parameters: Parameters, serializedData: ByteArray?): BackupMediaSnapshotSyncJob {
      val restoredSyncTime: Long = BackupMediaSnapshotSyncJobData.ADAPTER.decode(serializedData!!).syncTime
      return BackupMediaSnapshotSyncJob(restoredSyncTime, parameters)
    }
  }

  companion object {
    private val TAG = Log.tag(BackupMediaSnapshotSyncJob::class)

    const val KEY = "BackupMediaSnapshotSyncJob"

    private const val PAGE_SIZE = 500

    fun enqueue(backupSnapshotId: Long) {
      AppDependencies.jobManager.add(
        BackupMediaSnapshotSyncJob(
          backupSnapshotId,
          Parameters.Builder()
            .addConstraint(NetworkConstraint.KEY)
            .setMaxInstancesForFactory(1)
            .build()
        )
      )
    }
  }
}

View File

@@ -9,6 +9,7 @@ import org.signal.core.util.Stopwatch
import org.signal.core.util.logging.Log
import org.thoughtcrime.securesms.backup.ArchiveUploadProgress
import org.thoughtcrime.securesms.backup.v2.ArchiveValidator
import org.thoughtcrime.securesms.backup.v2.ArchivedMediaObjectIterator
import org.thoughtcrime.securesms.backup.v2.BackupRepository
import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.dependencies.AppDependencies
@@ -84,7 +85,11 @@ class BackupMessagesJob private constructor(parameters: Parameters) : Job(parame
val outputStream = FileOutputStream(tempBackupFile)
val backupKey = SignalStore.backup.messageBackupKey
BackupRepository.export(outputStream = outputStream, messageBackupKey = backupKey, append = { tempBackupFile.appendBytes(it) }, plaintext = false, cancellationSignal = { this.isCanceled })
val currentTime = System.currentTimeMillis()
BackupRepository.export(outputStream = outputStream, messageBackupKey = backupKey, append = { tempBackupFile.appendBytes(it) }, plaintext = false, cancellationSignal = { this.isCanceled }, currentTime = currentTime) {
writeMediaCursorToTemporaryTable(it, currentTime = currentTime, mediaBackupEnabled = SignalStore.backup.backsUpMedia)
}
stopwatch.split("export")
when (val result = ArchiveValidator.validate(tempBackupFile, backupKey)) {
@@ -156,9 +161,22 @@ class BackupMessagesJob private constructor(parameters: Parameters) : Job(parame
}
SignalStore.backup.clearMessageBackupFailure()
SignalDatabase.backupMediaSnapshots.commitPendingRows()
BackupMediaSnapshotSyncJob.enqueue(currentTime)
return Result.success()
}
/**
 * Snapshots the archive media ids of all archived attachments into the backup media snapshot
 * table as pending rows. No-op unless media backup is enabled.
 */
private fun writeMediaCursorToTemporaryTable(db: SignalDatabase, mediaBackupEnabled: Boolean, currentTime: Long) {
  if (!mediaBackupEnabled) {
    return
  }

  db.attachmentTable.getMediaIdCursor().use { cursor ->
    SignalDatabase.backupMediaSnapshots.writePendingMediaObjects(
      mediaObjects = ArchivedMediaObjectIterator(cursor).asSequence(),
      pendingSyncTime = currentTime
    )
  }
}
class Factory : Job.Factory<BackupMessagesJob> {
override fun create(parameters: Parameters, serializedData: ByteArray?): BackupMessagesJob {
return BackupMessagesJob(parameters)

View File

@@ -271,6 +271,7 @@ public final class JobManagerFactories {
put(BackfillDigestsMigrationJob.KEY, new BackfillDigestsMigrationJob.Factory());
put(BackfillDigestsForDuplicatesMigrationJob.KEY, new BackfillDigestsForDuplicatesMigrationJob.Factory());
put(BackupJitterMigrationJob.KEY, new BackupJitterMigrationJob.Factory());
put(BackupMediaSnapshotSyncJob.KEY, new BackupMediaSnapshotSyncJob.Factory());
put(BackupNotificationMigrationJob.KEY, new BackupNotificationMigrationJob.Factory());
put(BackupRefreshJob.KEY, new BackupRefreshJob.Factory());
put(BlobStorageLocationMigrationJob.KEY, new BlobStorageLocationMigrationJob.Factory());

View File

@@ -137,3 +137,7 @@ message UploadAttachmentToArchiveJobData {
uint64 attachmentId = 1;
ResumableUpload uploadSpec = 2;
}
// Serialized state for BackupMediaSnapshotSyncJob: the sync time of the backup snapshot
// relative to which older media objects should be deleted from the archive.
message BackupMediaSnapshotSyncJobData {
  uint64 syncTime = 1;
}

View File

@@ -0,0 +1,43 @@
package org.thoughtcrime.securesms.backup.v2
import org.junit.Before
import org.junit.Test
import org.mockito.kotlin.any
import org.mockito.kotlin.mock
import org.mockito.kotlin.whenever
import org.thoughtcrime.securesms.MockCursor
import org.thoughtcrime.securesms.assertIsSize
class ArchivedMediaObjectIteratorTest {

  private val cursor: MockCursor = mock()

  @Before
  fun setUp() {
    // Positioning methods use MockCursor's real implementations so iteration behaves like a
    // genuine cursor; string reads just return a fixed value.
    whenever(cursor.getString(any())).thenReturn("A")
    whenever(cursor.moveToPosition(any())).thenCallRealMethod()
    whenever(cursor.moveToNext()).thenCallRealMethod()
    whenever(cursor.position).thenCallRealMethod()
    whenever(cursor.isLast).thenCallRealMethod()
    whenever(cursor.isAfterLast).thenCallRealMethod()
  }

  @Test
  fun `Given a cursor with 0 items, when I convert to a list, then I expect a size of 0`() {
    assertIteratorProduces(0)
  }

  @Test
  fun `Given a cursor with 100 items, when I convert to a list, then I expect a size of 100`() {
    assertIteratorProduces(100)
  }

  /** Drains a fresh iterator over a cursor of [expectedSize] rows and checks the yield count. */
  private fun assertIteratorProduces(expectedSize: Int) {
    whenever(cursor.count).thenReturn(expectedSize)

    val drained = ArchivedMediaObjectIterator(cursor).asSequence().toList()

    drained.assertIsSize(expectedSize)
  }
}