Fix local backup restore AEP handling and conditional re-enable.

This commit is contained in:
Alex Hart
2026-03-19 12:50:00 -03:00
committed by Cody Henthorne
parent c7a6c7ad9e
commit 78d3db319c
16 changed files with 528 additions and 58 deletions

View File

@@ -0,0 +1,179 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.backup.v2.local
import android.app.Application
import android.content.Context
import androidx.test.core.app.ApplicationProvider
import assertk.assertThat
import assertk.assertions.isFalse
import assertk.assertions.isTrue
import okio.ByteString.Companion.toByteString
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
import org.robolectric.annotation.Config
import org.signal.core.models.backup.BackupId
import org.signal.core.models.backup.MessageBackupKey
import org.signal.core.util.Util
import org.thoughtcrime.securesms.backup.v2.local.proto.Metadata
import org.thoughtcrime.securesms.backup.v2.proto.AccountData
import org.thoughtcrime.securesms.backup.v2.proto.BackupInfo
import org.thoughtcrime.securesms.backup.v2.proto.Frame
import org.thoughtcrime.securesms.backup.v2.stream.EncryptedBackupWriter
import java.io.ByteArrayOutputStream
import javax.crypto.Cipher
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.SecretKeySpec
@RunWith(RobolectricTestRunner::class)
@Config(manifest = Config.NONE, application = Application::class)
class LocalArchiverTest {
@get:Rule
val temporaryFolder = TemporaryFolder()
private val context: Context = ApplicationProvider.getApplicationContext()
@Test
fun `canDecryptMainArchive returns true for valid key`() {
val messageBackupKey = MessageBackupKey(Util.getSecretBytes(32))
val snapshot = createSnapshot()
writeSnapshotFiles(snapshot, messageBackupKey)
assertThat(LocalArchiver.canDecryptMainArchive(snapshot, messageBackupKey)).isTrue()
}
@Test
fun `canDecryptMainArchive returns false for wrong key`() {
val validKey = MessageBackupKey(Util.getSecretBytes(32))
val wrongKey = MessageBackupKey(Util.getSecretBytes(32))
val snapshot = createSnapshot()
writeSnapshotFiles(snapshot, validKey)
assertThat(LocalArchiver.canDecryptMainArchive(snapshot, wrongKey)).isFalse()
}
@Test
fun `canDecryptMainArchive returns false when metadata is missing`() {
val messageBackupKey = MessageBackupKey(Util.getSecretBytes(32))
val snapshot = createSnapshot()
writeMainArchive(snapshot, messageBackupKey, BackupId(Util.getSecretBytes(16)))
assertThat(LocalArchiver.canDecryptMainArchive(snapshot, messageBackupKey)).isFalse()
}
@Test
fun `canDecryptMainArchive returns false when main archive is corrupted`() {
val messageBackupKey = MessageBackupKey(Util.getSecretBytes(32))
val snapshot = createSnapshot()
writeSnapshotFiles(snapshot, messageBackupKey, corruptMainArchive = true)
assertThat(LocalArchiver.canDecryptMainArchive(snapshot, messageBackupKey)).isFalse()
}
@Test
fun `getBackupId returns correct id for valid key`() {
val messageBackupKey = MessageBackupKey(Util.getSecretBytes(32))
val snapshot = createSnapshot()
val backupId = BackupId(Util.getSecretBytes(16))
snapshot.metadataOutputStream()!!.use { it.write(createMetadata(messageBackupKey, backupId).encode()) }
val result = LocalArchiver.getBackupId(snapshot, messageBackupKey)
assertThat(result?.value?.contentEquals(backupId.value) == true).isTrue()
}
@Test
fun `getBackupId returns null when metadata is missing`() {
val messageBackupKey = MessageBackupKey(Util.getSecretBytes(32))
val snapshot = createSnapshot()
assertThat(LocalArchiver.getBackupId(snapshot, messageBackupKey) == null).isTrue()
}
@Test
fun `getBackupId returns wrong id for wrong key`() {
val validKey = MessageBackupKey(Util.getSecretBytes(32))
val wrongKey = MessageBackupKey(Util.getSecretBytes(32))
val snapshot = createSnapshot()
val backupId = BackupId(Util.getSecretBytes(16))
snapshot.metadataOutputStream()!!.use { it.write(createMetadata(validKey, backupId).encode()) }
val result = LocalArchiver.getBackupId(snapshot, wrongKey)
assertThat(result?.value?.contentEquals(backupId.value) == true).isFalse()
}
private fun createSnapshot(): SnapshotFileSystem {
val archiveRoot = temporaryFolder.newFolder()
return ArchiveFileSystem.fromFile(context, archiveRoot).createSnapshot()!!
}
private fun writeSnapshotFiles(
snapshot: SnapshotFileSystem,
messageBackupKey: MessageBackupKey,
corruptMainArchive: Boolean = false
) {
val backupId = BackupId(Util.getSecretBytes(16))
snapshot.metadataOutputStream()!!.use { it.write(createMetadata(messageBackupKey, backupId).encode()) }
writeMainArchive(snapshot, messageBackupKey, backupId, corruptMainArchive)
}
private fun writeMainArchive(
snapshot: SnapshotFileSystem,
messageBackupKey: MessageBackupKey,
backupId: BackupId,
corruptMainArchive: Boolean = false
) {
val output = ByteArrayOutputStream()
EncryptedBackupWriter.createForLocalOrLinking(
key = messageBackupKey,
backupId = backupId,
outputStream = output,
append = { output.write(it) }
).use { writer ->
writer.write(BackupInfo(version = 1, backupTimeMs = 1000L))
writer.write(Frame(account = AccountData(username = "username")))
}
val bytes = output.toByteArray()
if (corruptMainArchive) {
bytes[bytes.lastIndex] = bytes.last().xor(0x01)
}
snapshot.mainOutputStream()!!.use { it.write(bytes) }
}
private fun createMetadata(messageBackupKey: MessageBackupKey, backupId: BackupId): Metadata {
val metadataKey = messageBackupKey.deriveLocalBackupMetadataKey()
val iv = Util.getSecretBytes(12)
val cipherText = Cipher.getInstance("AES/CTR/NoPadding").run {
init(Cipher.ENCRYPT_MODE, SecretKeySpec(metadataKey, "AES"), IvParameterSpec(iv))
doFinal(backupId.value)
}
return Metadata(
version = 1,
backupId = Metadata.EncryptedBackupId(
iv = iv.toByteString(),
encryptedId = cipherText.toByteString()
)
)
}
private fun Byte.xor(mask: Int): Byte {
return (toInt() xor mask).toByte()
}
}

View File

@@ -8,6 +8,7 @@ package org.thoughtcrime.securesms.backup.v2.stream
import org.junit.Assert.assertEquals
import org.junit.Test
import org.signal.core.models.ServiceId.ACI
import org.signal.core.models.backup.BackupId
import org.signal.core.models.backup.MessageBackupKey
import org.signal.core.util.Base64
import org.signal.core.util.Hex
@@ -100,6 +101,36 @@ class EncryptedBackupReaderWriterTest {
assertEquals(count, uniqueOutputs.size)
}
@Test
fun `can read back all frames using BackupId directly - local`() {
val key = MessageBackupKey(Util.getSecretBytes(32))
val backupId = BackupId(Util.getSecretBytes(16))
val outputStream = ByteArrayOutputStream()
val frameCount = 10_000
EncryptedBackupWriter.createForLocalOrLinking(key, backupId, outputStream, append = { outputStream.write(it) }).use { writer ->
writer.write(BackupInfo(version = 1, backupTimeMs = 1000L))
for (i in 0 until frameCount) {
writer.write(Frame(account = AccountData(username = "username-$i")))
}
}
val ciphertext: ByteArray = outputStream.toByteArray()
val frames: List<Frame> = EncryptedBackupReader.createForLocalOrLinking(key, backupId, ciphertext.size.toLong()) { ciphertext.inputStream() }.use { reader ->
assertEquals(reader.backupInfo?.version, 1L)
assertEquals(reader.backupInfo?.backupTimeMs, 1000L)
reader.asSequence().toList()
}
assertEquals(frameCount, frames.size)
for (i in 0 until frameCount) {
assertEquals("username-$i", frames[i].account?.username)
}
}
@Test
fun `can read back all of the frames we write - forward secrecy`() {
val key = MessageBackupKey(Util.getSecretBytes(32))
@@ -140,4 +171,73 @@ class EncryptedBackupReaderWriterTest {
assertEquals("username-$i", frames[i].account?.username)
}
}
@Test
fun `can read back all frames using BackupId directly - forward secrecy`() {
val key = MessageBackupKey(Util.getSecretBytes(32))
val backupId = BackupId(Util.getSecretBytes(16))
val forwardSecrecyToken = BackupForwardSecrecyToken(Util.getSecretBytes(32))
val outputStream = ByteArrayOutputStream()
val frameCount = 10_000
EncryptedBackupWriter.createForSignalBackup(
key = key,
backupId = backupId,
forwardSecrecyToken = forwardSecrecyToken,
forwardSecrecyMetadata = Util.getSecretBytes(64),
outputStream = outputStream,
append = { outputStream.write(it) }
).use { writer ->
writer.write(BackupInfo(version = 1, backupTimeMs = 1000L))
for (i in 0 until frameCount) {
writer.write(Frame(account = AccountData(username = "username-$i")))
}
}
val ciphertext: ByteArray = outputStream.toByteArray()
val frames: List<Frame> = EncryptedBackupReader.createForSignalBackup(key, backupId, forwardSecrecyToken, ciphertext.size.toLong()) { ciphertext.inputStream() }.use { reader ->
assertEquals(reader.backupInfo?.version, 1L)
assertEquals(reader.backupInfo?.backupTimeMs, 1000L)
reader.asSequence().toList()
}
assertEquals(frameCount, frames.size)
for (i in 0 until frameCount) {
assertEquals("username-$i", frames[i].account?.username)
}
}
@Test
fun `can write and read using BackupId for both - local`() {
val key = MessageBackupKey(Util.getSecretBytes(32))
val backupId = BackupId(Util.getSecretBytes(16))
val outputStream = ByteArrayOutputStream()
val frameCount = 10_000
EncryptedBackupWriter.createForLocalOrLinking(key, backupId, outputStream, append = { outputStream.write(it) }).use { writer ->
writer.write(BackupInfo(version = 1, backupTimeMs = 1000L))
for (i in 0 until frameCount) {
writer.write(Frame(account = AccountData(username = "username-$i")))
}
}
val ciphertext: ByteArray = outputStream.toByteArray()
val frames: List<Frame> = EncryptedBackupReader.createForLocalOrLinking(key, backupId, ciphertext.size.toLong()) { ciphertext.inputStream() }.use { reader ->
assertEquals(reader.backupInfo?.version, 1L)
assertEquals(reader.backupInfo?.backupTimeMs, 1000L)
reader.asSequence().toList()
}
assertEquals(frameCount, frames.size)
for (i in 0 until frameCount) {
assertEquals("username-$i", frames[i].account?.username)
}
}
}