From 74619f6f8deb2e413963fb3212b489c8c3ad1ece Mon Sep 17 00:00:00 2001 From: Greyson Parrelli Date: Thu, 30 Nov 2023 16:25:47 -0500 Subject: [PATCH] Prevent nested SQL error handlers. --- .../database/SqlCipherErrorHandler.kt | 69 +++++++++++-------- 1 file changed, 40 insertions(+), 29 deletions(-) diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/SqlCipherErrorHandler.kt b/app/src/main/java/org/thoughtcrime/securesms/database/SqlCipherErrorHandler.kt index deef3d0e4f..9802ae5d8c 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/SqlCipherErrorHandler.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/SqlCipherErrorHandler.kt @@ -11,6 +11,7 @@ import org.signal.core.util.logging.Log import org.thoughtcrime.securesms.crypto.DatabaseSecretProvider import org.thoughtcrime.securesms.dependencies.ApplicationDependencies import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicReference /** @@ -18,39 +19,53 @@ import java.util.concurrent.atomic.AtomicReference */ @Suppress("ClassName") class SqlCipherErrorHandler(private val databaseName: String) : DatabaseErrorHandler { + companion object { + private val TAG = Log.tag(SqlCipherErrorHandler::class.java) + + private val errorHandlingInProgress = AtomicBoolean(false) + } override fun onCorruption(db: SQLiteDatabase, message: String) { - val result: DiagnosticResults = runDiagnostics(ApplicationDependencies.getApplication(), db) - var lines: List = result.logs.split("\n") - lines = listOf("Database '$databaseName' corrupted!", "[sqlite] $message", "Diagnostics results:") + lines + if (errorHandlingInProgress.getAndSet(true)) { + Log.w(TAG, "Error handling already in progress, skipping.") + return + } - Log.e(TAG, "Database '$databaseName' corrupted!") - Log.e(TAG, "[sqlite] $message") - Log.e(TAG, "Diagnostic results:\n ${result.logs}") + try { + val result: DiagnosticResults = runDiagnostics(ApplicationDependencies.getApplication(), db) + var lines: List = result.logs.split("\n") + lines = listOf("Database '$databaseName' corrupted!", "[sqlite] $message", "Diagnostics results:") + lines - if (result is DiagnosticResults.Success) { - if (result.pragma1Passes && result.pragma2Passes) { - var endCount = 0 - while (db.inTransaction() && endCount < 10) { - db.endTransaction() - endCount++ + Log.e(TAG, "Database '$databaseName' corrupted!") + Log.e(TAG, "[sqlite] $message") + Log.e(TAG, "Diagnostic results:\n ${result.logs}") + + if (result is DiagnosticResults.Success) { + if (result.pragma1Passes && result.pragma2Passes) { + var endCount = 0 + while (db.inTransaction() && endCount < 10) { + db.endTransaction() + endCount++ + } + + attemptToClearFullTextSearchIndex(db) + throw DatabaseCorruptedError_BothChecksPass(lines) + } else if (!result.pragma1Passes && result.pragma2Passes) { + attemptToClearFullTextSearchIndex(db) + throw DatabaseCorruptedError_NormalCheckFailsCipherCheckPasses(lines) + } else if (result.pragma1Passes && !result.pragma2Passes) { + attemptToClearFullTextSearchIndex(db) + throw DatabaseCorruptedError_NormalCheckPassesCipherCheckFails(lines) + } else { + attemptToClearFullTextSearchIndex(db) + throw DatabaseCorruptedError_BothChecksFail(lines) } - - attemptToClearFullTextSearchIndex(db) - throw DatabaseCorruptedError_BothChecksPass(lines) - } else if (!result.pragma1Passes && result.pragma2Passes) { - attemptToClearFullTextSearchIndex(db) - throw DatabaseCorruptedError_NormalCheckFailsCipherCheckPasses(lines) - } else if (result.pragma1Passes && !result.pragma2Passes) { - attemptToClearFullTextSearchIndex(db) - throw DatabaseCorruptedError_NormalCheckPassesCipherCheckFails(lines) } else { attemptToClearFullTextSearchIndex(db) - throw DatabaseCorruptedError_BothChecksFail(lines) + throw DatabaseCorruptedError_FailedToRunChecks(lines) } - } else { - attemptToClearFullTextSearchIndex(db) - throw DatabaseCorruptedError_FailedToRunChecks(lines) + } finally { + errorHandlingInProgress.set(false) } } @@ -184,8 +199,4 @@ class SqlCipherErrorHandler(private val databaseName: String) : DatabaseErrorHan private class DatabaseCorruptedError_NormalCheckFailsCipherCheckPasses constructor(lines: List) : CustomTraceError(lines) private class DatabaseCorruptedError_NormalCheckPassesCipherCheckFails constructor(lines: List) : CustomTraceError(lines) private class DatabaseCorruptedError_FailedToRunChecks constructor(lines: List) : CustomTraceError(lines) - - companion object { - private val TAG = Log.tag(SqlCipherErrorHandler::class.java) - } }