Guarantee table export order is valid.

This commit is contained in:
Greyson Parrelli
2022-11-16 13:21:33 -05:00
committed by Alex Hart
parent 7c60c32918
commit cf00995b6f
3 changed files with 262 additions and 18 deletions

View File

@@ -8,6 +8,7 @@ import android.text.TextUtils;
import androidx.annotation.NonNull; import androidx.annotation.NonNull;
import androidx.annotation.Nullable; import androidx.annotation.Nullable;
import androidx.annotation.RequiresApi; import androidx.annotation.RequiresApi;
import androidx.annotation.VisibleForTesting;
import androidx.documentfile.provider.DocumentFile; import androidx.documentfile.provider.DocumentFile;
import com.annimon.stream.function.Predicate; import com.annimon.stream.function.Predicate;
@@ -19,6 +20,7 @@ import org.greenrobot.eventbus.EventBus;
import org.signal.core.util.Conversions; import org.signal.core.util.Conversions;
import org.signal.core.util.CursorUtil; import org.signal.core.util.CursorUtil;
import org.signal.core.util.SetUtil; import org.signal.core.util.SetUtil;
import org.signal.core.util.SqlUtil;
import org.signal.core.util.Stopwatch; import org.signal.core.util.Stopwatch;
import org.signal.core.util.logging.Log; import org.signal.core.util.logging.Log;
import org.signal.libsignal.protocol.kdf.HKDF; import org.signal.libsignal.protocol.kdf.HKDF;
@@ -62,10 +64,15 @@ import java.io.OutputStream;
import java.security.InvalidAlgorithmParameterException; import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException; import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.util.LinkedList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors;
import javax.crypto.BadPaddingException; import javax.crypto.BadPaddingException;
import javax.crypto.Cipher; import javax.crypto.Cipher;
@@ -84,7 +91,11 @@ public class FullBackupExporter extends FullBackupBase {
private static final long IDENTITY_KEY_BACKUP_RECORD_COUNT = 2L; private static final long IDENTITY_KEY_BACKUP_RECORD_COUNT = 2L;
private static final long FINAL_MESSAGE_COUNT = 1L; private static final long FINAL_MESSAGE_COUNT = 1L;
private static final Set<String> BLACKLISTED_TABLES = SetUtil.newHashSet( /**
* Tables in list will still have their *schema* exported (so the tables will be created),
* but we will not export the actual contents.
*/
private static final Set<String> TABLE_CONTENT_BLOCKLIST = SetUtil.newHashSet(
SignedPreKeyDatabase.TABLE_NAME, SignedPreKeyDatabase.TABLE_NAME,
OneTimePreKeyDatabase.TABLE_NAME, OneTimePreKeyDatabase.TABLE_NAME,
SessionDatabase.TABLE_NAME, SessionDatabase.TABLE_NAME,
@@ -175,7 +186,7 @@ public class FullBackupExporter extends FullBackupBase {
count = exportTable(table, input, outputStream, cursor -> isForNonExpiringMmsMessage(input, cursor.getLong(cursor.getColumnIndexOrThrow(AttachmentDatabase.MMS_ID))), (cursor, innerCount) -> exportAttachment(attachmentSecret, cursor, outputStream, innerCount, estimatedCount), count, estimatedCount, cancellationSignal); count = exportTable(table, input, outputStream, cursor -> isForNonExpiringMmsMessage(input, cursor.getLong(cursor.getColumnIndexOrThrow(AttachmentDatabase.MMS_ID))), (cursor, innerCount) -> exportAttachment(attachmentSecret, cursor, outputStream, innerCount, estimatedCount), count, estimatedCount, cancellationSignal);
} else if (table.equals(StickerDatabase.TABLE_NAME)) { } else if (table.equals(StickerDatabase.TABLE_NAME)) {
count = exportTable(table, input, outputStream, cursor -> true, (cursor, innerCount) -> exportSticker(attachmentSecret, cursor, outputStream, innerCount, estimatedCount), count, estimatedCount, cancellationSignal); count = exportTable(table, input, outputStream, cursor -> true, (cursor, innerCount) -> exportSticker(attachmentSecret, cursor, outputStream, innerCount, estimatedCount), count, estimatedCount, cancellationSignal);
} else if (!BLACKLISTED_TABLES.contains(table) && !table.startsWith("sqlite_")) { } else if (!TABLE_CONTENT_BLOCKLIST.contains(table)) {
count = exportTable(table, input, outputStream, null, null, count, estimatedCount, cancellationSignal); count = exportTable(table, input, outputStream, null, null, count, estimatedCount, cancellationSignal);
} }
stopwatch.split("table::" + table); stopwatch.split("table::" + table);
@@ -229,7 +240,7 @@ public class FullBackupExporter extends FullBackupBase {
count += getCount(input, BackupCountQueries.getAttachmentCount()); count += getCount(input, BackupCountQueries.getAttachmentCount());
} else if (table.equals(StickerDatabase.TABLE_NAME)) { } else if (table.equals(StickerDatabase.TABLE_NAME)) {
count += getCount(input, "SELECT COUNT(*) FROM " + table); count += getCount(input, "SELECT COUNT(*) FROM " + table);
} else if (!BLACKLISTED_TABLES.contains(table) && !table.startsWith("sqlite_")) { } else if (!TABLE_CONTENT_BLOCKLIST.contains(table)) {
count += getCount(input, "SELECT COUNT(*) FROM " + table); count += getCount(input, "SELECT COUNT(*) FROM " + table);
} }
} }
@@ -266,31 +277,112 @@ public class FullBackupExporter extends FullBackupBase {
private static List<String> exportSchema(@NonNull SQLiteDatabase input, @NonNull BackupFrameOutputStream outputStream) private static List<String> exportSchema(@NonNull SQLiteDatabase input, @NonNull BackupFrameOutputStream outputStream)
throws IOException throws IOException
{ {
List<String> tables = new LinkedList<>(); List<String> tablesInOrder = getTablesToExportInOrder(input);
try (Cursor cursor = input.rawQuery("SELECT sql, name, type FROM sqlite_master", null)) { Map<String, String> createStatementsByTable = new HashMap<>();
try (Cursor cursor = input.rawQuery("SELECT sql, name, type FROM sqlite_master WHERE type = 'table' AND sql NOT NULL", null)) {
while (cursor != null && cursor.moveToNext()) { while (cursor != null && cursor.moveToNext()) {
String sql = cursor.getString(0); String sql = cursor.getString(0);
String name = cursor.getString(1); String name = cursor.getString(1);
String type = cursor.getString(2);
if (sql != null) { createStatementsByTable.put(name, sql);
boolean isSmsFtsSecretTable = name != null && !name.equals(SearchDatabase.SMS_FTS_TABLE_NAME) && name.startsWith(SearchDatabase.SMS_FTS_TABLE_NAME); }
boolean isMmsFtsSecretTable = name != null && !name.equals(SearchDatabase.MMS_FTS_TABLE_NAME) && name.startsWith(SearchDatabase.MMS_FTS_TABLE_NAME); }
boolean isEmojiFtsSecretTable = name != null && !name.equals(EmojiSearchDatabase.TABLE_NAME) && name.startsWith(EmojiSearchDatabase.TABLE_NAME);
if (!isSmsFtsSecretTable && !isMmsFtsSecretTable && !isEmojiFtsSecretTable) { for (String table : tablesInOrder) {
if ("table".equals(type)) { String statement = createStatementsByTable.get(table);
tables.add(name);
}
outputStream.write(BackupProtos.SqlStatement.newBuilder().setStatement(cursor.getString(0)).build()); if (statement != null) {
} outputStream.write(BackupProtos.SqlStatement.newBuilder().setStatement(statement).build());
} else {
throw new IOException("Failed to find a create statement for table: " + table);
}
}
try (Cursor cursor = input.rawQuery("SELECT sql, name, type FROM sqlite_master where type != 'table' AND sql NOT NULL", null)) {
while (cursor != null && cursor.moveToNext()) {
String sql = cursor.getString(0);
String name = cursor.getString(1);
if (isTableAllowed(name)) {
outputStream.write(BackupProtos.SqlStatement.newBuilder().setStatement(sql).build());
} }
} }
} }
return tables; return tablesInOrder;
}
/**
* Returns the list of tables we should export, in the order they should be exported in.
* The order is chosen to ensure we won't violate any foreign key constraints when we import them.
*/
private static List<String> getTablesToExportInOrder(@NonNull SQLiteDatabase input) {
List<String> tables = SqlUtil.getAllTables(input)
.stream()
.filter(FullBackupExporter::isTableAllowed)
.sorted()
.collect(Collectors.toList());
Map<String, Set<String>> dependsOn = new LinkedHashMap<>();
for (String table : tables) {
dependsOn.put(table, SqlUtil.getForeignKeyDependencies(input, table));
}
return computeTableOrder(dependsOn);
}
@VisibleForTesting
static List<String> computeTableOrder(@NonNull Map<String, Set<String>> dependsOn) {
List<String> rootNodes = dependsOn.keySet()
.stream()
.filter(table -> {
boolean nothingDependsOnIt = dependsOn.values().stream().noneMatch(it -> it.contains(table));
return nothingDependsOnIt;
})
.sorted()
.collect(Collectors.toList());
LinkedHashSet<String> outputOrder = new LinkedHashSet<>();
for (String root : rootNodes) {
postOrderTraversal(root, dependsOn, outputOrder);
}
return new ArrayList<>(outputOrder);
}
private static void postOrderTraversal(String current, Map<String, Set<String>> dependsOn, LinkedHashSet<String> outputOrder) {
Set<String> dependencies = dependsOn.get(current);
if (dependencies == null || dependencies.isEmpty()) {
outputOrder.add(current);
return;
}
for (String dependency : dependencies) {
postOrderTraversal(dependency, dependsOn, outputOrder);
}
outputOrder.add(current);
}
private static boolean isTableAllowed(@Nullable String table) {
if (table == null) {
return true;
}
boolean isReservedTable = table.startsWith("sqlite_");
boolean isSmsFtsSecretTable = !table.equals(SearchDatabase.SMS_FTS_TABLE_NAME) && table.startsWith(SearchDatabase.SMS_FTS_TABLE_NAME);
boolean isMmsFtsSecretTable = !table.equals(SearchDatabase.MMS_FTS_TABLE_NAME) && table.startsWith(SearchDatabase.MMS_FTS_TABLE_NAME);
boolean isEmojiFtsSecretTable = !table.equals(EmojiSearchDatabase.TABLE_NAME) && table.startsWith(EmojiSearchDatabase.TABLE_NAME);
return !isReservedTable &&
!isSmsFtsSecretTable &&
!isMmsFtsSecretTable &&
!isEmojiFtsSecretTable;
} }
private static int exportTable(@NonNull String table, private static int exportTable(@NonNull String table,

View File

@@ -0,0 +1,141 @@
package org.thoughtcrime.securesms.backup
import org.junit.Assert.assertEquals
import org.junit.Test
class FullBackupExporterTest {
@Test
fun `computeTableOrder - empty`() {
val order = FullBackupExporter.computeTableOrder(mapOf())
assertEquals(listOf<String>(), order)
}
/**
* A B C
*/
@Test
fun `computeTableOrder - no dependencies`() {
val order = FullBackupExporter.computeTableOrder(
mapOf(
"A" to setOf(),
"B" to setOf(),
"C" to setOf(),
)
)
assertEquals(listOf("A", "B", "C"), order)
}
/**
* C
* |
* B
* |
* A
*/
@Test
fun `computeTableOrder - single chain`() {
val order = FullBackupExporter.computeTableOrder(
mapOf(
"C" to setOf("B"),
"B" to setOf("A"),
)
)
assertEquals(listOf("A", "B", "C"), order)
}
/**
* ┌──F──┐ G H
* ┌─B ┌─E─┐
* A C D
*/
@Test
fun `computeTableOrder - complex 1`() {
val order = FullBackupExporter.computeTableOrder(
mapOf(
"F" to setOf("B", "E"),
"B" to setOf("A"),
"E" to setOf("C", "D"),
"G" to setOf(),
"H" to setOf(),
"A" to setOf(),
"C" to setOf(),
"D" to setOf(),
)
)
assertEquals(listOf("A", "B", "C", "D", "E", "F", "G", "H"), order)
}
/**
* ┌────I────┐
* │ | │
* ┌─C─┐ E ┌─H─┐
* │ │ | │ │
* A B D F G
*/
@Test
fun `computeTableOrder - complex 2`() {
val order = FullBackupExporter.computeTableOrder(
mapOf(
"I" to setOf("C", "E", "H"),
"C" to setOf("A", "B"),
"E" to setOf("D"),
"H" to setOf("F", "G"),
"A" to setOf(),
"B" to setOf(),
"D" to setOf(),
"F" to setOf(),
"G" to setOf(),
)
)
assertEquals(listOf("A", "B", "C", "D", "E", "F", "G", "H", "I"), order)
}
/**
* ┌─C─┐ E ┌─H─┐
* │ │ | │ │
* A B D F G
*/
@Test
fun `computeTableOrder - multiple roots`() {
val order = FullBackupExporter.computeTableOrder(
mapOf(
"C" to setOf("A", "B"),
"E" to setOf("D"),
"H" to setOf("F", "G"),
"A" to setOf(),
"B" to setOf(),
"D" to setOf(),
"F" to setOf(),
"G" to setOf(),
)
)
assertEquals(listOf("A", "B", "C", "D", "E", "F", "G", "H"), order)
}
/**
* ┌─C─┐ D ┌─E─┐
* │ │ | │ │
* A B A A B
*/
@Test
fun `computeTableOrder - multiple roots, dupes across graphs`() {
val order = FullBackupExporter.computeTableOrder(
mapOf(
"C" to setOf("A", "B"),
"D" to setOf("A"),
"E" to setOf("A", "B"),
"A" to setOf(),
"B" to setOf(),
)
)
assertEquals(listOf("A", "B", "C", "D", "E"), order)
}
}

View File

@@ -33,6 +33,17 @@ object SqlUtil {
return tables return tables
} }
/**
* Given a table, this will return a set of tables that it has a foreign key dependency on.
*/
@JvmStatic
fun getForeignKeyDependencies(db: SupportSQLiteDatabase, table: String): Set<String> {
return db.query("PRAGMA foreign_key_list($table)")
.readToSet{ cursor ->
cursor.requireNonNullString("table")
}
}
@JvmStatic @JvmStatic
fun isEmpty(db: SupportSQLiteDatabase, table: String): Boolean { fun isEmpty(db: SupportSQLiteDatabase, table: String): Boolean {
db.query("SELECT COUNT(*) FROM $table", null).use { cursor -> db.query("SELECT COUNT(*) FROM $table", null).use { cursor ->