diff --git a/app/src/main/java/org/thoughtcrime/securesms/backup/FullBackupExporter.java b/app/src/main/java/org/thoughtcrime/securesms/backup/FullBackupExporter.java index 0bc390dc57..29f8026eff 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/backup/FullBackupExporter.java +++ b/app/src/main/java/org/thoughtcrime/securesms/backup/FullBackupExporter.java @@ -8,6 +8,7 @@ import android.text.TextUtils; import androidx.annotation.NonNull; import androidx.annotation.Nullable; import androidx.annotation.RequiresApi; +import androidx.annotation.VisibleForTesting; import androidx.documentfile.provider.DocumentFile; import com.annimon.stream.function.Predicate; @@ -19,6 +20,7 @@ import org.greenrobot.eventbus.EventBus; import org.signal.core.util.Conversions; import org.signal.core.util.CursorUtil; import org.signal.core.util.SetUtil; +import org.signal.core.util.SqlUtil; import org.signal.core.util.Stopwatch; import org.signal.core.util.logging.Log; import org.signal.libsignal.protocol.kdf.HKDF; @@ -62,10 +64,15 @@ import java.io.OutputStream; import java.security.InvalidAlgorithmParameterException; import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; -import java.util.LinkedList; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; import javax.crypto.BadPaddingException; import javax.crypto.Cipher; @@ -84,7 +91,11 @@ public class FullBackupExporter extends FullBackupBase { private static final long IDENTITY_KEY_BACKUP_RECORD_COUNT = 2L; private static final long FINAL_MESSAGE_COUNT = 1L; - private static final Set BLACKLISTED_TABLES = SetUtil.newHashSet( + /** + * Tables in list will still have their *schema* exported (so the tables will be created), + * but we will not export the actual contents. + */ + private static final Set TABLE_CONTENT_BLOCKLIST = SetUtil.newHashSet( SignedPreKeyDatabase.TABLE_NAME, OneTimePreKeyDatabase.TABLE_NAME, SessionDatabase.TABLE_NAME, @@ -175,7 +186,7 @@ public class FullBackupExporter extends FullBackupBase { count = exportTable(table, input, outputStream, cursor -> isForNonExpiringMmsMessage(input, cursor.getLong(cursor.getColumnIndexOrThrow(AttachmentDatabase.MMS_ID))), (cursor, innerCount) -> exportAttachment(attachmentSecret, cursor, outputStream, innerCount, estimatedCount), count, estimatedCount, cancellationSignal); } else if (table.equals(StickerDatabase.TABLE_NAME)) { count = exportTable(table, input, outputStream, cursor -> true, (cursor, innerCount) -> exportSticker(attachmentSecret, cursor, outputStream, innerCount, estimatedCount), count, estimatedCount, cancellationSignal); - } else if (!BLACKLISTED_TABLES.contains(table) && !table.startsWith("sqlite_")) { + } else if (!TABLE_CONTENT_BLOCKLIST.contains(table)) { count = exportTable(table, input, outputStream, null, null, count, estimatedCount, cancellationSignal); } stopwatch.split("table::" + table); @@ -229,7 +240,7 @@ public class FullBackupExporter extends FullBackupBase { count += getCount(input, BackupCountQueries.getAttachmentCount()); } else if (table.equals(StickerDatabase.TABLE_NAME)) { count += getCount(input, "SELECT COUNT(*) FROM " + table); - } else if (!BLACKLISTED_TABLES.contains(table) && !table.startsWith("sqlite_")) { + } else if (!TABLE_CONTENT_BLOCKLIST.contains(table)) { count += getCount(input, "SELECT COUNT(*) FROM " + table); } } @@ -266,31 +277,112 @@ public class FullBackupExporter extends FullBackupBase { private static List exportSchema(@NonNull SQLiteDatabase input, @NonNull BackupFrameOutputStream outputStream) throws IOException { - List tables = new LinkedList<>(); + List tablesInOrder = getTablesToExportInOrder(input); - try (Cursor cursor = input.rawQuery("SELECT sql, name, type FROM sqlite_master", null)) { + Map createStatementsByTable = new HashMap<>(); + + try (Cursor cursor = input.rawQuery("SELECT sql, name, type FROM sqlite_master WHERE type = 'table' AND sql NOT NULL", null)) { while (cursor != null && cursor.moveToNext()) { String sql = cursor.getString(0); String name = cursor.getString(1); - String type = cursor.getString(2); - if (sql != null) { - boolean isSmsFtsSecretTable = name != null && !name.equals(SearchDatabase.SMS_FTS_TABLE_NAME) && name.startsWith(SearchDatabase.SMS_FTS_TABLE_NAME); - boolean isMmsFtsSecretTable = name != null && !name.equals(SearchDatabase.MMS_FTS_TABLE_NAME) && name.startsWith(SearchDatabase.MMS_FTS_TABLE_NAME); - boolean isEmojiFtsSecretTable = name != null && !name.equals(EmojiSearchDatabase.TABLE_NAME) && name.startsWith(EmojiSearchDatabase.TABLE_NAME); + createStatementsByTable.put(name, sql); + } + } - if (!isSmsFtsSecretTable && !isMmsFtsSecretTable && !isEmojiFtsSecretTable) { - if ("table".equals(type)) { - tables.add(name); - } + for (String table : tablesInOrder) { + String statement = createStatementsByTable.get(table); - outputStream.write(BackupProtos.SqlStatement.newBuilder().setStatement(cursor.getString(0)).build()); - } + if (statement != null) { + outputStream.write(BackupProtos.SqlStatement.newBuilder().setStatement(statement).build()); + } else { + throw new IOException("Failed to find a create statement for table: " + table); + } + } + + try (Cursor cursor = input.rawQuery("SELECT sql, name, type FROM sqlite_master where type != 'table' AND sql NOT NULL", null)) { + while (cursor != null && cursor.moveToNext()) { + String sql = cursor.getString(0); + String name = cursor.getString(1); + + if (isTableAllowed(name)) { + outputStream.write(BackupProtos.SqlStatement.newBuilder().setStatement(sql).build()); } } } - return tables; + return tablesInOrder; + } + + /** + * Returns the list of tables we should export, in the order they should be exported in. + * The order is chosen to ensure we won't violate any foreign key constraints when we import them. + */ + private static List getTablesToExportInOrder(@NonNull SQLiteDatabase input) { + List tables = SqlUtil.getAllTables(input) + .stream() + .filter(FullBackupExporter::isTableAllowed) + .sorted() + .collect(Collectors.toList()); + + + Map> dependsOn = new LinkedHashMap<>(); + for (String table : tables) { + dependsOn.put(table, SqlUtil.getForeignKeyDependencies(input, table)); + } + + return computeTableOrder(dependsOn); + } + + @VisibleForTesting + static List computeTableOrder(@NonNull Map> dependsOn) { + List rootNodes = dependsOn.keySet() + .stream() + .filter(table -> { + boolean nothingDependsOnIt = dependsOn.values().stream().noneMatch(it -> it.contains(table)); + return nothingDependsOnIt; + }) + .sorted() + .collect(Collectors.toList()); + + LinkedHashSet outputOrder = new LinkedHashSet<>(); + + for (String root : rootNodes) { + postOrderTraversal(root, dependsOn, outputOrder); + } + + return new ArrayList<>(outputOrder); + } + + private static void postOrderTraversal(String current, Map> dependsOn, LinkedHashSet outputOrder) { + Set dependencies = dependsOn.get(current); + + if (dependencies == null || dependencies.isEmpty()) { + outputOrder.add(current); + return; + } + + for (String dependency : dependencies) { + postOrderTraversal(dependency, dependsOn, outputOrder); + } + + outputOrder.add(current); + } + + private static boolean isTableAllowed(@Nullable String table) { + if (table == null) { + return true; + } + + boolean isReservedTable = table.startsWith("sqlite_"); + boolean isSmsFtsSecretTable = !table.equals(SearchDatabase.SMS_FTS_TABLE_NAME) && table.startsWith(SearchDatabase.SMS_FTS_TABLE_NAME); + boolean isMmsFtsSecretTable = !table.equals(SearchDatabase.MMS_FTS_TABLE_NAME) && table.startsWith(SearchDatabase.MMS_FTS_TABLE_NAME); + boolean isEmojiFtsSecretTable = !table.equals(EmojiSearchDatabase.TABLE_NAME) && table.startsWith(EmojiSearchDatabase.TABLE_NAME); + + return !isReservedTable && + !isSmsFtsSecretTable && + !isMmsFtsSecretTable && + !isEmojiFtsSecretTable; } private static int exportTable(@NonNull String table, diff --git a/app/src/test/java/org/thoughtcrime/securesms/backup/FullBackupExporterTest.kt b/app/src/test/java/org/thoughtcrime/securesms/backup/FullBackupExporterTest.kt new file mode 100644 index 0000000000..70295b378d --- /dev/null +++ b/app/src/test/java/org/thoughtcrime/securesms/backup/FullBackupExporterTest.kt @@ -0,0 +1,141 @@ +package org.thoughtcrime.securesms.backup + +import org.junit.Assert.assertEquals +import org.junit.Test + +class FullBackupExporterTest { + + @Test + fun `computeTableOrder - empty`() { + val order = FullBackupExporter.computeTableOrder(mapOf()) + + assertEquals(listOf(), order) + } + + /** + * A B C + */ + @Test + fun `computeTableOrder - no dependencies`() { + val order = FullBackupExporter.computeTableOrder( + mapOf( + "A" to setOf(), + "B" to setOf(), + "C" to setOf(), + ) + ) + + assertEquals(listOf("A", "B", "C"), order) + } + + /** + * C + * | + * B + * | + * A + */ + @Test + fun `computeTableOrder - single chain`() { + val order = FullBackupExporter.computeTableOrder( + mapOf( + "C" to setOf("B"), + "B" to setOf("A"), + ) + ) + + assertEquals(listOf("A", "B", "C"), order) + } + + /** + * ┌──F──┐ G H + * ┌─B ┌─E─┐ + * A C D + */ + @Test + fun `computeTableOrder - complex 1`() { + val order = FullBackupExporter.computeTableOrder( + mapOf( + "F" to setOf("B", "E"), + "B" to setOf("A"), + "E" to setOf("C", "D"), + "G" to setOf(), + "H" to setOf(), + "A" to setOf(), + "C" to setOf(), + "D" to setOf(), + ) + ) + + assertEquals(listOf("A", "B", "C", "D", "E", "F", "G", "H"), order) + } + + /** + * ┌────I────┐ + * │ | │ + * ┌─C─┐ E ┌─H─┐ + * │ │ | │ │ + * A B D F G + */ + @Test + fun `computeTableOrder - complex 2`() { + val order = FullBackupExporter.computeTableOrder( + mapOf( + "I" to setOf("C", "E", "H"), + "C" to setOf("A", "B"), + "E" to setOf("D"), + "H" to setOf("F", "G"), + "A" to setOf(), + "B" to setOf(), + "D" to setOf(), + "F" to setOf(), + "G" to setOf(), + ) + ) + + assertEquals(listOf("A", "B", "C", "D", "E", "F", "G", "H", "I"), order) + } + + /** + * ┌─C─┐ E ┌─H─┐ + * │ │ | │ │ + * A B D F G + */ + @Test + fun `computeTableOrder - multiple roots`() { + val order = FullBackupExporter.computeTableOrder( + mapOf( + "C" to setOf("A", "B"), + "E" to setOf("D"), + "H" to setOf("F", "G"), + "A" to setOf(), + "B" to setOf(), + "D" to setOf(), + "F" to setOf(), + "G" to setOf(), + ) + ) + + assertEquals(listOf("A", "B", "C", "D", "E", "F", "G", "H"), order) + } + + /** + * ┌─C─┐ D ┌─E─┐ + * │ │ | │ │ + * A B A A B + */ + @Test + fun `computeTableOrder - multiple roots, dupes across graphs`() { + val order = FullBackupExporter.computeTableOrder( + mapOf( + "C" to setOf("A", "B"), + "D" to setOf("A"), + "E" to setOf("A", "B"), + "A" to setOf(), + "B" to setOf(), + ) + ) + + assertEquals(listOf("A", "B", "C", "D", "E"), order) + } +} diff --git a/core-util/src/main/java/org/signal/core/util/SqlUtil.kt b/core-util/src/main/java/org/signal/core/util/SqlUtil.kt index 687b466a1a..b1164f925f 100644 --- a/core-util/src/main/java/org/signal/core/util/SqlUtil.kt +++ b/core-util/src/main/java/org/signal/core/util/SqlUtil.kt @@ -33,6 +33,17 @@ object SqlUtil { return tables } + /** + * Given a table, this will return a set of tables that it has a foreign key dependency on. + */ + @JvmStatic + fun getForeignKeyDependencies(db: SupportSQLiteDatabase, table: String): Set { + return db.query("PRAGMA foreign_key_list($table)") + .readToSet{ cursor -> + cursor.requireNonNullString("table") + } + } + @JvmStatic fun isEmpty(db: SupportSQLiteDatabase, table: String): Boolean { db.query("SELECT COUNT(*) FROM $table", null).use { cursor ->