Add initial support for backup and restore of message and media to staging.

Co-authored-by: Cody Henthorne <cody@signal.org>
This commit is contained in:
Clark
2024-04-12 11:57:34 -04:00
committed by Greyson Parrelli
parent 8617a074ad
commit 689eacd618
71 changed files with 3198 additions and 744 deletions
@@ -1008,17 +1008,47 @@ class ImportExportTest {
attachmentLocator = FilePointer.AttachmentLocator( attachmentLocator = FilePointer.AttachmentLocator(
cdnKey = "coolCdnKey", cdnKey = "coolCdnKey",
cdnNumber = 2, cdnNumber = 2,
uploadTimestamp = System.currentTimeMillis() uploadTimestamp = System.currentTimeMillis(),
key = (1..32).map { it.toByte() }.toByteArray().toByteString(),
size = 12345,
digest = (1..32).map { it.toByte() }.toByteArray().toByteString()
), ),
key = (1..32).map { it.toByte() }.toByteArray().toByteString(),
contentType = "image/png", contentType = "image/png",
size = 12345,
fileName = "very_cool_picture.png", fileName = "very_cool_picture.png",
width = 100, width = 100,
height = 200, height = 200,
caption = "Love this cool picture!", caption = "Love this cool picture!",
incrementalMacChunkSize = 0 incrementalMacChunkSize = 0
) ),
wasDownloaded = true
),
MessageAttachment(
pointer = FilePointer(
invalidAttachmentLocator = FilePointer.InvalidAttachmentLocator(),
contentType = "image/png",
width = 100,
height = 200,
caption = "Love this cool picture! Too bad u cant download it",
incrementalMacChunkSize = 0
),
wasDownloaded = false
),
MessageAttachment(
pointer = FilePointer(
backupLocator = FilePointer.BackupLocator(
"digestherebutimlazy",
cdnNumber = 3,
key = (1..32).map { it.toByte() }.toByteArray().toByteString(),
digest = (1..64).map { it.toByte() }.toByteArray().toByteString(),
size = 12345
),
contentType = "image/png",
width = 100,
height = 200,
caption = "Love this cool picture! Too bad u cant download it",
incrementalMacChunkSize = 0
),
wasDownloaded = true
) )
) )
) )
@@ -7,6 +7,7 @@ import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith import org.junit.runner.RunWith
import org.signal.core.util.ThreadUtil import org.signal.core.util.ThreadUtil
import org.thoughtcrime.securesms.attachments.Cdn
import org.thoughtcrime.securesms.attachments.PointerAttachment import org.thoughtcrime.securesms.attachments.PointerAttachment
import org.thoughtcrime.securesms.conversation.v2.ConversationActivity import org.thoughtcrime.securesms.conversation.v2.ConversationActivity
import org.thoughtcrime.securesms.database.MessageType import org.thoughtcrime.securesms.database.MessageType
@@ -15,7 +16,6 @@ import org.thoughtcrime.securesms.mms.IncomingMessage
import org.thoughtcrime.securesms.mms.OutgoingMessage import org.thoughtcrime.securesms.mms.OutgoingMessage
import org.thoughtcrime.securesms.profiles.ProfileName import org.thoughtcrime.securesms.profiles.ProfileName
import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.releasechannel.ReleaseChannel
import org.thoughtcrime.securesms.testing.SignalActivityRule import org.thoughtcrime.securesms.testing.SignalActivityRule
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId
@@ -137,7 +137,7 @@ class ConversationItemPreviewer {
private fun attachment(): SignalServiceAttachmentPointer { private fun attachment(): SignalServiceAttachmentPointer {
return SignalServiceAttachmentPointer( return SignalServiceAttachmentPointer(
ReleaseChannel.CDN_NUMBER, Cdn.CDN_3.cdnNumber,
SignalServiceAttachmentRemoteId.from(""), SignalServiceAttachmentRemoteId.from(""),
"image/webp", "image/webp",
null, null,
@@ -14,6 +14,7 @@ import org.junit.runner.RunWith
import org.signal.core.util.Base64 import org.signal.core.util.Base64
import org.signal.core.util.update import org.signal.core.util.update
import org.thoughtcrime.securesms.attachments.AttachmentId import org.thoughtcrime.securesms.attachments.AttachmentId
import org.thoughtcrime.securesms.attachments.Cdn
import org.thoughtcrime.securesms.attachments.PointerAttachment import org.thoughtcrime.securesms.attachments.PointerAttachment
import org.thoughtcrime.securesms.database.AttachmentTable.TransformProperties import org.thoughtcrime.securesms.database.AttachmentTable.TransformProperties
import org.thoughtcrime.securesms.keyvalue.SignalStore import org.thoughtcrime.securesms.keyvalue.SignalStore
@@ -742,7 +743,7 @@ class AttachmentTableTest_deduping {
assertArrayEquals(lhsAttachment.remoteDigest, rhsAttachment.remoteDigest) assertArrayEquals(lhsAttachment.remoteDigest, rhsAttachment.remoteDigest)
assertArrayEquals(lhsAttachment.incrementalDigest, rhsAttachment.incrementalDigest) assertArrayEquals(lhsAttachment.incrementalDigest, rhsAttachment.incrementalDigest)
assertEquals(lhsAttachment.incrementalMacChunkSize, rhsAttachment.incrementalMacChunkSize) assertEquals(lhsAttachment.incrementalMacChunkSize, rhsAttachment.incrementalMacChunkSize)
assertEquals(lhsAttachment.cdnNumber, rhsAttachment.cdnNumber) assertEquals(lhsAttachment.cdn.cdnNumber, rhsAttachment.cdn.cdnNumber)
} }
fun assertDoesNotHaveRemoteFields(attachmentId: AttachmentId) { fun assertDoesNotHaveRemoteFields(attachmentId: AttachmentId) {
@@ -751,7 +752,7 @@ class AttachmentTableTest_deduping {
assertNull(databaseAttachment.remoteLocation) assertNull(databaseAttachment.remoteLocation)
assertNull(databaseAttachment.remoteDigest) assertNull(databaseAttachment.remoteDigest)
assertNull(databaseAttachment.remoteKey) assertNull(databaseAttachment.remoteKey)
assertEquals(0, databaseAttachment.cdnNumber) assertEquals(0, databaseAttachment.cdn.cdnNumber)
} }
fun assertSkipTransform(attachmentId: AttachmentId, state: Boolean) { fun assertSkipTransform(attachmentId: AttachmentId, state: Boolean) {
@@ -776,7 +777,7 @@ class AttachmentTableTest_deduping {
AttachmentTable.TRANSFER_PROGRESS_DONE, AttachmentTable.TRANSFER_PROGRESS_DONE,
databaseAttachment.size, // size databaseAttachment.size, // size
null, null,
3, // cdnNumber Cdn.CDN_3, // cdnNumber
location, location,
key, key,
digest, digest,
+4
View File
@@ -1162,6 +1162,10 @@
android:name=".service.AttachmentProgressService" android:name=".service.AttachmentProgressService"
android:exported="false"/> android:exported="false"/>
<service
android:name=".service.BackupProgressService"
android:exported="false"/>
<service <service
android:name=".gcm.FcmFetchBackgroundService" android:name=".gcm.FcmFetchBackgroundService"
android:exported="false"/> android:exported="false"/>
@@ -0,0 +1,88 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.attachments
import android.net.Uri
import android.os.Parcel
import org.signal.core.util.Base64
import org.thoughtcrime.securesms.blurhash.BlurHash
import org.thoughtcrime.securesms.database.AttachmentTable
class ArchivedAttachment : Attachment {
@JvmField
val archiveCdn: Int
@JvmField
val archiveMediaName: String
@JvmField
val archiveMediaId: String
constructor(
contentType: String?,
size: Long,
cdn: Cdn,
cdnKey: ByteArray,
archiveMediaName: String,
archiveMediaId: String,
digest: ByteArray,
incrementalMac: ByteArray?,
incrementalMacChunkSize: Int?,
width: Int?,
height: Int?,
caption: String?,
blurHash: String?,
voiceNote: Boolean,
borderless: Boolean,
gif: Boolean,
quote: Boolean
) : super(
contentType = contentType ?: "",
quote = quote,
transferState = AttachmentTable.TRANSFER_NEEDS_RESTORE,
size = size,
fileName = null,
cdn = cdn,
remoteLocation = null,
remoteKey = Base64.encodeWithoutPadding(cdnKey),
remoteDigest = digest,
incrementalDigest = incrementalMac,
fastPreflightId = null,
voiceNote = voiceNote,
borderless = borderless,
videoGif = gif,
width = width ?: 0,
height = height ?: 0,
incrementalMacChunkSize = incrementalMacChunkSize ?: 0,
uploadTimestamp = 0,
caption = caption,
stickerLocator = null,
blurHash = BlurHash.parseOrNull(blurHash),
audioHash = null,
transformProperties = null
) {
this.archiveCdn = cdn.cdnNumber
this.archiveMediaName = archiveMediaName
this.archiveMediaId = archiveMediaId
}
constructor(parcel: Parcel) : super(parcel) {
archiveCdn = parcel.readInt()
archiveMediaName = parcel.readString()!!
archiveMediaId = parcel.readString()!!
}
override fun writeToParcel(dest: Parcel, flags: Int) {
super.writeToParcel(dest, flags)
dest.writeInt(archiveCdn)
dest.writeString(archiveMediaName)
dest.writeString(archiveMediaId)
}
override val uri: Uri? = null
override val publicUri: Uri? = null
}
@@ -29,7 +29,7 @@ abstract class Attachment(
@JvmField @JvmField
val fileName: String?, val fileName: String?,
@JvmField @JvmField
val cdnNumber: Int, val cdn: Cdn,
@JvmField @JvmField
val remoteLocation: String?, val remoteLocation: String?,
@JvmField @JvmField
@@ -76,7 +76,7 @@ abstract class Attachment(
transferState = parcel.readInt(), transferState = parcel.readInt(),
size = parcel.readLong(), size = parcel.readLong(),
fileName = parcel.readString(), fileName = parcel.readString(),
cdnNumber = parcel.readInt(), cdn = Cdn.deserialize(parcel.readInt()),
remoteLocation = parcel.readString(), remoteLocation = parcel.readString(),
remoteKey = parcel.readString(), remoteKey = parcel.readString(),
remoteDigest = ParcelUtil.readByteArray(parcel), remoteDigest = ParcelUtil.readByteArray(parcel),
@@ -103,7 +103,7 @@ abstract class Attachment(
dest.writeInt(transferState) dest.writeInt(transferState)
dest.writeLong(size) dest.writeLong(size)
dest.writeString(fileName) dest.writeString(fileName)
dest.writeInt(cdnNumber) dest.writeInt(cdn.serialize())
dest.writeString(remoteLocation) dest.writeString(remoteLocation)
dest.writeString(remoteKey) dest.writeString(remoteKey)
ParcelUtil.writeByteArray(dest, remoteDigest) ParcelUtil.writeByteArray(dest, remoteDigest)
@@ -17,7 +17,8 @@ object AttachmentCreator : Parcelable.Creator<Attachment> {
DATABASE(DatabaseAttachment::class.java, "database"), DATABASE(DatabaseAttachment::class.java, "database"),
POINTER(PointerAttachment::class.java, "pointer"), POINTER(PointerAttachment::class.java, "pointer"),
TOMBSTONE(TombstoneAttachment::class.java, "tombstone"), TOMBSTONE(TombstoneAttachment::class.java, "tombstone"),
URI(UriAttachment::class.java, "uri") URI(UriAttachment::class.java, "uri"),
ARCHIVED(ArchivedAttachment::class.java, "archived")
} }
@JvmStatic @JvmStatic
@@ -34,6 +35,7 @@ object AttachmentCreator : Parcelable.Creator<Attachment> {
Subclass.POINTER -> PointerAttachment(source) Subclass.POINTER -> PointerAttachment(source)
Subclass.TOMBSTONE -> TombstoneAttachment(source) Subclass.TOMBSTONE -> TombstoneAttachment(source)
Subclass.URI -> UriAttachment(source) Subclass.URI -> UriAttachment(source)
Subclass.ARCHIVED -> ArchivedAttachment(source)
} }
} }
@@ -0,0 +1,53 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.attachments
import org.signal.core.util.IntSerializer
/**
* Attachments/media can come from and go to multiple CDN locations depending on when and where
* they were uploaded. This class represents the CDNs where attachments/media can live.
*/
enum class Cdn(private val value: Int) {
S3(-1),
CDN_0(0),
CDN_2(2),
CDN_3(3);
val cdnNumber: Int
get() {
return when (this) {
S3 -> -1
CDN_0 -> 0
CDN_2 -> 2
CDN_3 -> 3
}
}
fun serialize(): Int {
return Serializer.serialize(this)
}
companion object Serializer : IntSerializer<Cdn> {
override fun serialize(data: Cdn): Int {
return data.value
}
override fun deserialize(data: Int): Cdn {
return values().first { it.value == data }
}
fun fromCdnNumber(cdnNumber: Int): Cdn {
return when (cdnNumber) {
-1 -> S3
0 -> CDN_0
2 -> CDN_2
3 -> CDN_3
else -> throw UnsupportedOperationException()
}
}
}
}
@@ -25,6 +25,15 @@ class DatabaseAttachment : Attachment {
@JvmField @JvmField
val dataHash: String? val dataHash: String?
@JvmField
val archiveCdn: Int
@JvmField
val archiveMediaName: String?
@JvmField
val archiveMediaId: String?
private val hasThumbnail: Boolean private val hasThumbnail: Boolean
val displayOrder: Int val displayOrder: Int
@@ -37,7 +46,7 @@ class DatabaseAttachment : Attachment {
transferProgress: Int, transferProgress: Int,
size: Long, size: Long,
fileName: String?, fileName: String?,
cdnNumber: Int, cdn: Cdn,
location: String?, location: String?,
key: String?, key: String?,
digest: ByteArray?, digest: ByteArray?,
@@ -57,13 +66,16 @@ class DatabaseAttachment : Attachment {
transformProperties: TransformProperties?, transformProperties: TransformProperties?,
displayOrder: Int, displayOrder: Int,
uploadTimestamp: Long, uploadTimestamp: Long,
dataHash: String? dataHash: String?,
archiveCdn: Int,
archiveMediaName: String?,
archiveMediaId: String?
) : super( ) : super(
contentType = contentType!!, contentType = contentType!!,
transferState = transferProgress, transferState = transferProgress,
size = size, size = size,
fileName = fileName, fileName = fileName,
cdnNumber = cdnNumber, cdn = cdn,
remoteLocation = location, remoteLocation = location,
remoteKey = key, remoteKey = key,
remoteDigest = digest, remoteDigest = digest,
@@ -88,6 +100,9 @@ class DatabaseAttachment : Attachment {
this.dataHash = dataHash this.dataHash = dataHash
this.hasThumbnail = hasThumbnail this.hasThumbnail = hasThumbnail
this.displayOrder = displayOrder this.displayOrder = displayOrder
this.archiveCdn = archiveCdn
this.archiveMediaName = archiveMediaName
this.archiveMediaId = archiveMediaId
} }
constructor(parcel: Parcel) : super(parcel) { constructor(parcel: Parcel) : super(parcel) {
@@ -97,6 +112,9 @@ class DatabaseAttachment : Attachment {
hasThumbnail = ParcelUtil.readBoolean(parcel) hasThumbnail = ParcelUtil.readBoolean(parcel)
mmsId = parcel.readLong() mmsId = parcel.readLong()
displayOrder = parcel.readInt() displayOrder = parcel.readInt()
archiveCdn = parcel.readInt()
archiveMediaName = parcel.readString()
archiveMediaId = parcel.readString()
} }
override fun writeToParcel(dest: Parcel, flags: Int) { override fun writeToParcel(dest: Parcel, flags: Int) {
@@ -107,6 +125,9 @@ class DatabaseAttachment : Attachment {
ParcelUtil.writeBoolean(dest, hasThumbnail) ParcelUtil.writeBoolean(dest, hasThumbnail)
dest.writeLong(mmsId) dest.writeLong(mmsId)
dest.writeInt(displayOrder) dest.writeInt(displayOrder)
dest.writeInt(archiveCdn)
dest.writeString(archiveMediaName)
dest.writeString(archiveMediaId)
} }
override val uri: Uri? override val uri: Uri?
@@ -9,7 +9,6 @@ import org.thoughtcrime.securesms.database.AttachmentTable
import org.thoughtcrime.securesms.stickers.StickerLocator import org.thoughtcrime.securesms.stickers.StickerLocator
import org.whispersystems.signalservice.api.InvalidMessageStructureException import org.whispersystems.signalservice.api.InvalidMessageStructureException
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment import org.whispersystems.signalservice.api.messages.SignalServiceAttachment
import org.whispersystems.signalservice.api.messages.SignalServiceDataMessage
import org.whispersystems.signalservice.api.util.AttachmentPointerUtil import org.whispersystems.signalservice.api.util.AttachmentPointerUtil
import org.whispersystems.signalservice.internal.push.DataMessage import org.whispersystems.signalservice.internal.push.DataMessage
import java.util.Optional import java.util.Optional
@@ -21,7 +20,7 @@ class PointerAttachment : Attachment {
transferState: Int, transferState: Int,
size: Long, size: Long,
fileName: String?, fileName: String?,
cdnNumber: Int, cdn: Cdn,
location: String, location: String,
key: String?, key: String?,
digest: ByteArray?, digest: ByteArray?,
@@ -42,7 +41,7 @@ class PointerAttachment : Attachment {
transferState = transferState, transferState = transferState,
size = size, size = size,
fileName = fileName, fileName = fileName,
cdnNumber = cdnNumber, cdn = cdn,
remoteLocation = location, remoteLocation = location,
remoteKey = key, remoteKey = key,
remoteDigest = digest, remoteDigest = digest,
@@ -83,7 +82,7 @@ class PointerAttachment : Attachment {
@JvmStatic @JvmStatic
@JvmOverloads @JvmOverloads
fun forPointer(pointer: Optional<SignalServiceAttachment>, stickerLocator: StickerLocator? = null, fastPreflightId: String? = null): Optional<Attachment> { fun forPointer(pointer: Optional<SignalServiceAttachment>, stickerLocator: StickerLocator? = null, fastPreflightId: String? = null, transferState: Int = AttachmentTable.TRANSFER_PROGRESS_PENDING): Optional<Attachment> {
if (!pointer.isPresent || !pointer.get().isPointer) { if (!pointer.isPresent || !pointer.get().isPointer) {
return Optional.empty() return Optional.empty()
} }
@@ -97,10 +96,10 @@ class PointerAttachment : Attachment {
return Optional.of( return Optional.of(
PointerAttachment( PointerAttachment(
contentType = pointer.get().contentType, contentType = pointer.get().contentType,
transferState = AttachmentTable.TRANSFER_PROGRESS_PENDING, transferState = transferState,
size = pointer.get().asPointer().size.orElse(0).toLong(), size = pointer.get().asPointer().size.orElse(0).toLong(),
fileName = pointer.get().asPointer().fileName.orElse(null), fileName = pointer.get().asPointer().fileName.orElse(null),
cdnNumber = pointer.get().asPointer().cdnNumber, cdn = Cdn.fromCdnNumber(pointer.get().asPointer().cdnNumber),
location = pointer.get().asPointer().remoteId.toString(), location = pointer.get().asPointer().remoteId.toString(),
key = encodedKey, key = encodedKey,
digest = pointer.get().asPointer().digest.orElse(null), digest = pointer.get().asPointer().digest.orElse(null),
@@ -120,35 +119,6 @@ class PointerAttachment : Attachment {
) )
} }
fun forPointer(pointer: SignalServiceDataMessage.Quote.QuotedAttachment): Optional<Attachment> {
val thumbnail = pointer.thumbnail
return Optional.of(
PointerAttachment(
contentType = pointer.contentType,
transferState = AttachmentTable.TRANSFER_PROGRESS_PENDING,
size = (if (thumbnail != null) thumbnail.asPointer().size.orElse(0) else 0).toLong(),
fileName = pointer.fileName,
cdnNumber = thumbnail?.asPointer()?.cdnNumber ?: 0,
location = thumbnail?.asPointer()?.remoteId?.toString() ?: "0",
key = if (thumbnail != null && thumbnail.asPointer().key != null) encodeWithPadding(thumbnail.asPointer().key) else null,
digest = thumbnail?.asPointer()?.digest?.orElse(null),
incrementalDigest = thumbnail?.asPointer()?.incrementalDigest?.orElse(null),
incrementalMacChunkSize = thumbnail?.asPointer()?.incrementalMacChunkSize ?: 0,
fastPreflightId = null,
voiceNote = false,
borderless = false,
videoGif = false,
width = thumbnail?.asPointer()?.width ?: 0,
height = thumbnail?.asPointer()?.height ?: 0,
uploadTimestamp = thumbnail?.asPointer()?.uploadTimestamp ?: 0,
caption = thumbnail?.asPointer()?.caption?.orElse(null),
stickerLocator = null,
blurHash = null
)
)
}
fun forPointer(quotedAttachment: DataMessage.Quote.QuotedAttachment): Optional<Attachment> { fun forPointer(quotedAttachment: DataMessage.Quote.QuotedAttachment): Optional<Attachment> {
val thumbnail: SignalServiceAttachment? = try { val thumbnail: SignalServiceAttachment? = try {
if (quotedAttachment.thumbnail != null) { if (quotedAttachment.thumbnail != null) {
@@ -166,7 +136,7 @@ class PointerAttachment : Attachment {
transferState = AttachmentTable.TRANSFER_PROGRESS_PENDING, transferState = AttachmentTable.TRANSFER_PROGRESS_PENDING,
size = (if (thumbnail != null) thumbnail.asPointer().size.orElse(0) else 0).toLong(), size = (if (thumbnail != null) thumbnail.asPointer().size.orElse(0) else 0).toLong(),
fileName = quotedAttachment.fileName, fileName = quotedAttachment.fileName,
cdnNumber = thumbnail?.asPointer()?.cdnNumber ?: 0, cdn = Cdn.fromCdnNumber(thumbnail?.asPointer()?.cdnNumber ?: 0),
location = thumbnail?.asPointer()?.remoteId?.toString() ?: "0", location = thumbnail?.asPointer()?.remoteId?.toString() ?: "0",
key = if (thumbnail != null && thumbnail.asPointer().key != null) encodeWithPadding(thumbnail.asPointer().key) else null, key = if (thumbnail != null && thumbnail.asPointer().key != null) encodeWithPadding(thumbnail.asPointer().key) else null,
digest = thumbnail?.asPointer()?.digest?.orElse(null), digest = thumbnail?.asPointer()?.digest?.orElse(null),
@@ -2,6 +2,7 @@ package org.thoughtcrime.securesms.attachments
import android.net.Uri import android.net.Uri
import android.os.Parcel import android.os.Parcel
import org.thoughtcrime.securesms.blurhash.BlurHash
import org.thoughtcrime.securesms.database.AttachmentTable import org.thoughtcrime.securesms.database.AttachmentTable
/** /**
@@ -17,7 +18,7 @@ class TombstoneAttachment : Attachment {
transferState = AttachmentTable.TRANSFER_PROGRESS_DONE, transferState = AttachmentTable.TRANSFER_PROGRESS_DONE,
size = 0, size = 0,
fileName = null, fileName = null,
cdnNumber = 0, cdn = Cdn.CDN_0,
remoteLocation = null, remoteLocation = null,
remoteKey = null, remoteKey = null,
remoteDigest = null, remoteDigest = null,
@@ -37,6 +38,44 @@ class TombstoneAttachment : Attachment {
transformProperties = null transformProperties = null
) )
constructor(
contentType: String?,
incrementalMac: ByteArray?,
incrementalMacChunkSize: Int?,
width: Int?,
height: Int?,
caption: String?,
blurHash: String?,
voiceNote: Boolean = false,
borderless: Boolean = false,
gif: Boolean = false,
quote: Boolean
) : super(
contentType = contentType ?: "",
quote = quote,
transferState = AttachmentTable.TRANSFER_PROGRESS_PERMANENT_FAILURE,
size = 0,
fileName = null,
cdn = Cdn.CDN_0,
remoteLocation = null,
remoteKey = null,
remoteDigest = null,
incrementalDigest = incrementalMac,
fastPreflightId = null,
voiceNote = voiceNote,
borderless = borderless,
videoGif = gif,
width = width ?: 0,
height = height ?: 0,
incrementalMacChunkSize = incrementalMacChunkSize ?: 0,
uploadTimestamp = 0,
caption = caption,
stickerLocator = null,
blurHash = BlurHash.parseOrNull(blurHash),
audioHash = null,
transformProperties = null
)
constructor(parcel: Parcel) : super(parcel) constructor(parcel: Parcel) : super(parcel)
override val uri: Uri? = null override val uri: Uri? = null
@@ -69,7 +69,7 @@ class UriAttachment : Attachment {
transferState = transferState, transferState = transferState,
size = size, size = size,
fileName = fileName, fileName = fileName,
cdnNumber = 0, cdn = Cdn.CDN_0,
remoteLocation = null, remoteLocation = null,
remoteKey = null, remoteKey = null,
remoteDigest = null, remoteDigest = null,
@@ -0,0 +1,36 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.backup
import org.signal.core.util.LongSerializer
enum class RestoreState(val id: Int, val inProgress: Boolean) {
FAILED(-1, false),
NONE(0, false),
PENDING(1, true),
RESTORING_DB(2, true),
RESTORING_MEDIA(3, true);
companion object {
val serializer: LongSerializer<RestoreState> = Serializer()
}
class Serializer : LongSerializer<RestoreState> {
override fun serialize(data: RestoreState): Long {
return data.id.toLong()
}
override fun deserialize(data: Long): RestoreState {
return when (data.toInt()) {
FAILED.id -> FAILED
PENDING.id -> PENDING
RESTORING_DB.id -> RESTORING_DB
RESTORING_MEDIA.id -> RESTORING_MEDIA
else -> NONE
}
}
}
}
@@ -1,5 +1,5 @@
/* /*
* Copyright 2023 Signal Messenger, LLC * Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@@ -14,6 +14,7 @@ import org.signal.libsignal.messagebackup.MessageBackup.ValidationResult
import org.signal.libsignal.messagebackup.MessageBackupKey import org.signal.libsignal.messagebackup.MessageBackupKey
import org.signal.libsignal.protocol.ServiceId.Aci import org.signal.libsignal.protocol.ServiceId.Aci
import org.signal.libsignal.zkgroup.profiles.ProfileKey import org.signal.libsignal.zkgroup.profiles.ProfileKey
import org.thoughtcrime.securesms.attachments.AttachmentId
import org.thoughtcrime.securesms.attachments.DatabaseAttachment import org.thoughtcrime.securesms.attachments.DatabaseAttachment
import org.thoughtcrime.securesms.backup.v2.database.ChatItemImportInserter import org.thoughtcrime.securesms.backup.v2.database.ChatItemImportInserter
import org.thoughtcrime.securesms.backup.v2.database.clearAllDataForBackupRestore import org.thoughtcrime.securesms.backup.v2.database.clearAllDataForBackupRestore
@@ -37,17 +38,20 @@ import org.thoughtcrime.securesms.recipients.RecipientId
import org.whispersystems.signalservice.api.NetworkResult import org.whispersystems.signalservice.api.NetworkResult
import org.whispersystems.signalservice.api.archive.ArchiveGetMediaItemsResponse import org.whispersystems.signalservice.api.archive.ArchiveGetMediaItemsResponse
import org.whispersystems.signalservice.api.archive.ArchiveMediaRequest import org.whispersystems.signalservice.api.archive.ArchiveMediaRequest
import org.whispersystems.signalservice.api.archive.ArchiveMediaResponse
import org.whispersystems.signalservice.api.archive.ArchiveServiceCredential import org.whispersystems.signalservice.api.archive.ArchiveServiceCredential
import org.whispersystems.signalservice.api.archive.BatchArchiveMediaResponse
import org.whispersystems.signalservice.api.archive.DeleteArchivedMediaRequest import org.whispersystems.signalservice.api.archive.DeleteArchivedMediaRequest
import org.whispersystems.signalservice.api.archive.GetArchiveCdnCredentialsResponse
import org.whispersystems.signalservice.api.backup.BackupKey import org.whispersystems.signalservice.api.backup.BackupKey
import org.whispersystems.signalservice.api.backup.MediaName
import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment.ProgressListener
import org.whispersystems.signalservice.api.push.ServiceId.ACI import org.whispersystems.signalservice.api.push.ServiceId.ACI
import org.whispersystems.signalservice.api.push.ServiceId.PNI import org.whispersystems.signalservice.api.push.ServiceId.PNI
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream import org.whispersystems.signalservice.internal.crypto.PaddingInputStream
import java.io.ByteArrayOutputStream import java.io.ByteArrayOutputStream
import java.io.File
import java.io.InputStream import java.io.InputStream
import java.io.OutputStream
import kotlin.time.Duration.Companion.milliseconds import kotlin.time.Duration.Companion.milliseconds
object BackupRepository { object BackupRepository {
@@ -55,10 +59,8 @@ object BackupRepository {
private val TAG = Log.tag(BackupRepository::class.java) private val TAG = Log.tag(BackupRepository::class.java)
private const val VERSION = 1L private const val VERSION = 1L
fun export(plaintext: Boolean = false): ByteArray { fun export(outputStream: OutputStream, append: (ByteArray) -> Unit, plaintext: Boolean = false) {
val eventTimer = EventTimer() val eventTimer = EventTimer()
val outputStream = ByteArrayOutputStream()
val writer: BackupExportWriter = if (plaintext) { val writer: BackupExportWriter = if (plaintext) {
PlainTextBackupWriter(outputStream) PlainTextBackupWriter(outputStream)
} else { } else {
@@ -66,11 +68,11 @@ object BackupRepository {
key = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey(), key = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey(),
aci = SignalStore.account().aci!!, aci = SignalStore.account().aci!!,
outputStream = outputStream, outputStream = outputStream,
append = { mac -> outputStream.write(mac) } append = append
) )
} }
val exportState = ExportState(System.currentTimeMillis()) val exportState = ExportState(backupTime = System.currentTimeMillis(), allowMediaBackup = true)
writer.use { writer.use {
writer.write( writer.write(
@@ -110,7 +112,11 @@ object BackupRepository {
} }
Log.d(TAG, "export() ${eventTimer.stop().summary}") Log.d(TAG, "export() ${eventTimer.stop().summary}")
}
fun export(plaintext: Boolean = false): ByteArray {
val outputStream = ByteArrayOutputStream()
export(outputStream = outputStream, append = { mac -> outputStream.write(mac) }, plaintext = plaintext)
return outputStream.toByteArray() return outputStream.toByteArray()
} }
@@ -124,11 +130,13 @@ object BackupRepository {
fun import(length: Long, inputStreamFactory: () -> InputStream, selfData: SelfData, plaintext: Boolean = false) { fun import(length: Long, inputStreamFactory: () -> InputStream, selfData: SelfData, plaintext: Boolean = false) {
val eventTimer = EventTimer() val eventTimer = EventTimer()
val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey()
val frameReader = if (plaintext) { val frameReader = if (plaintext) {
PlainTextBackupReader(inputStreamFactory()) PlainTextBackupReader(inputStreamFactory())
} else { } else {
EncryptedBackupReader( EncryptedBackupReader(
key = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey(), key = backupKey,
aci = selfData.aci, aci = selfData.aci,
streamLength = length, streamLength = length,
dataStream = inputStreamFactory dataStream = inputStreamFactory
@@ -160,7 +168,7 @@ object BackupRepository {
SignalDatabase.recipients.setProfileSharing(selfId, true) SignalDatabase.recipients.setProfileSharing(selfId, true)
eventTimer.emit("setup") eventTimer.emit("setup")
val backupState = BackupState() val backupState = BackupState(backupKey)
val chatItemInserter: ChatItemImportInserter = ChatItemBackupProcessor.beginImport(backupState) val chatItemInserter: ChatItemImportInserter = ChatItemBackupProcessor.beginImport(backupState)
for (frame in frameReader) { for (frame in frameReader) {
@@ -281,6 +289,24 @@ object BackupRepository {
.also { Log.i(TAG, "OverallResult: $it") } is NetworkResult.Success .also { Log.i(TAG, "OverallResult: $it") } is NetworkResult.Success
} }
fun downloadBackupFile(destination: File, listener: ProgressListener? = null): Boolean {
val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi
val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey()
return api
.triggerBackupIdReservation(backupKey)
.then { getAuthCredential() }
.then { credential ->
api.getBackupInfo(backupKey, credential)
}
.then { info -> getCdnReadCredentials().map { it.headers to info } }
.map { pair ->
val (cdnCredentials, info) = pair
val messageReceiver = ApplicationDependencies.getSignalServiceMessageReceiver()
messageReceiver.retrieveBackup(info.cdn!!, cdnCredentials, "backups/${info.backupDir}/${info.backupName}", destination, listener)
} is NetworkResult.Success
}
/** /**
* Returns an object with details about the remote backup state. * Returns an object with details about the remote backup state.
*/ */
@@ -296,7 +322,7 @@ object BackupRepository {
} }
} }
fun archiveMedia(attachment: DatabaseAttachment): NetworkResult<ArchiveMediaResponse> { fun archiveMedia(attachment: DatabaseAttachment): NetworkResult<Unit> {
val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi
val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey() val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey()
@@ -304,16 +330,23 @@ object BackupRepository {
.triggerBackupIdReservation(backupKey) .triggerBackupIdReservation(backupKey)
.then { getAuthCredential() } .then { getAuthCredential() }
.then { credential -> .then { credential ->
api.archiveAttachmentMedia( val mediaName = attachment.getMediaName()
backupKey = backupKey, val request = attachment.toArchiveMediaRequest(mediaName, backupKey)
serviceCredential = credential, api
item = attachment.toArchiveMediaRequest(backupKey) .archiveAttachmentMedia(
) backupKey = backupKey,
serviceCredential = credential,
item = request
)
.map { Triple(mediaName, request.mediaId, it) }
} }
.also { Log.i(TAG, "backupMediaResult: $it") } .map { (mediaName, mediaId, response) ->
SignalDatabase.attachments.setArchiveData(attachmentId = attachment.attachmentId, archiveCdn = response.cdn, archiveMediaName = mediaName.name, archiveMediaId = mediaId)
}
.also { Log.i(TAG, "archiveMediaResult: $it") }
} }
fun archiveMedia(attachments: List<DatabaseAttachment>): NetworkResult<BatchArchiveMediaResponse> { fun archiveMedia(databaseAttachments: List<DatabaseAttachment>): NetworkResult<BatchArchiveMediaResult> {
val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi
val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey() val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey()
@@ -321,24 +354,55 @@ object BackupRepository {
.triggerBackupIdReservation(backupKey) .triggerBackupIdReservation(backupKey)
.then { getAuthCredential() } .then { getAuthCredential() }
.then { credential -> .then { credential ->
api.archiveAttachmentMedia( val requests = mutableListOf<ArchiveMediaRequest>()
backupKey = backupKey, val mediaIdToAttachmentId = mutableMapOf<String, AttachmentId>()
serviceCredential = credential, val attachmentIdToMediaName = mutableMapOf<AttachmentId, String>()
items = attachments.map { it.toArchiveMediaRequest(backupKey) }
) databaseAttachments.forEach {
val mediaName = it.getMediaName()
val request = it.toArchiveMediaRequest(mediaName, backupKey)
requests += request
mediaIdToAttachmentId[request.mediaId] = it.attachmentId
attachmentIdToMediaName[it.attachmentId] = mediaName.name
}
api
.archiveAttachmentMedia(
backupKey = backupKey,
serviceCredential = credential,
items = requests
)
.map { BatchArchiveMediaResult(it, mediaIdToAttachmentId, attachmentIdToMediaName) }
} }
.also { Log.i(TAG, "backupMediaResult: $it") } .map { result ->
result
.successfulResponses
.forEach {
val attachmentId = result.mediaIdToAttachmentId(it.mediaId)
val mediaName = result.attachmentIdToMediaName(attachmentId)
SignalDatabase.attachments.setArchiveData(attachmentId = attachmentId, archiveCdn = it.cdn!!, archiveMediaName = mediaName, archiveMediaId = it.mediaId)
}
result
}
.also { Log.i(TAG, "archiveMediaResult: $it") }
} }
fun deleteArchivedMedia(attachments: List<DatabaseAttachment>): NetworkResult<Unit> { fun deleteArchivedMedia(attachments: List<DatabaseAttachment>): NetworkResult<Unit> {
val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi
val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey() val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey()
val mediaToDelete = attachments.map { val mediaToDelete = attachments
DeleteArchivedMediaRequest.ArchivedMediaObject( .filter { it.archiveMediaId != null }
cdn = 3, // TODO [cody] store and reuse backup cdn returned from copy/move call .map {
mediaId = backupKey.deriveMediaId(Base64.decode(it.dataHash!!)).toString() DeleteArchivedMediaRequest.ArchivedMediaObject(
) cdn = it.archiveCdn,
mediaId = it.archiveMediaId!!
)
}
if (mediaToDelete.isEmpty()) {
Log.i(TAG, "No media to delete, quick success")
return NetworkResult.Success(Unit)
} }
return getAuthCredential() return getAuthCredential()
@@ -349,7 +413,101 @@ object BackupRepository {
mediaToDelete = mediaToDelete mediaToDelete = mediaToDelete
) )
} }
.also { Log.i(TAG, "deleteBackupMediaResult: $it") } .map {
SignalDatabase.attachments.clearArchiveData(attachments.map { it.attachmentId })
}
.also { Log.i(TAG, "deleteArchivedMediaResult: $it") }
}
fun debugDeleteAllArchivedMedia(): NetworkResult<Unit> {
val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi
val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey()
return debugGetArchivedMediaState()
.then { archivedMedia ->
val mediaToDelete = archivedMedia
.map {
DeleteArchivedMediaRequest.ArchivedMediaObject(
cdn = it.cdn,
mediaId = it.mediaId
)
}
if (mediaToDelete.isEmpty()) {
Log.i(TAG, "No media to delete, quick success")
NetworkResult.Success(Unit)
} else {
getAuthCredential()
.then { credential ->
api.deleteArchivedMedia(
backupKey = backupKey,
serviceCredential = credential,
mediaToDelete = mediaToDelete
)
}
}
}
.map {
SignalDatabase.attachments.clearAllArchiveData()
}
.also { Log.i(TAG, "debugDeleteAllArchivedMediaResult: $it") }
}
/**
* Retrieve credentials for reading from the backup cdn.
*/
fun getCdnReadCredentials(): NetworkResult<GetArchiveCdnCredentialsResponse> {
val cached = SignalStore.backup().cdnReadCredentials
if (cached != null) {
return NetworkResult.Success(cached)
}
val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi
val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey()
return getAuthCredential()
.then { credential ->
api.getCdnReadCredentials(
backupKey = backupKey,
serviceCredential = credential
)
}
.also {
if (it is NetworkResult.Success) {
SignalStore.backup().cdnReadCredentials = it.result
}
}
.also { Log.i(TAG, "getCdnReadCredentialsResult: $it") }
}
/**
* Retrieves backupDir and mediaDir, preferring cached value if available.
*
* These will only ever change if the backup expires.
*/
fun getCdnBackupDirectories(): NetworkResult<BackupDirectories> {
val cachedBackupDirectory = SignalStore.backup().cachedBackupDirectory
val cachedBackupMediaDirectory = SignalStore.backup().cachedBackupMediaDirectory
if (cachedBackupDirectory != null && cachedBackupMediaDirectory != null) {
return NetworkResult.Success(BackupDirectories(cachedBackupDirectory, cachedBackupMediaDirectory))
}
val api = ApplicationDependencies.getSignalServiceAccountManager().archiveApi
val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey()
return getAuthCredential()
.then { credential ->
api.getBackupInfo(backupKey, credential).map {
BackupDirectories(it.backupDir!!, it.mediaDir!!)
}
}
.also {
if (it is NetworkResult.Success) {
SignalStore.backup().cachedBackupDirectory = it.result.backupDir
SignalStore.backup().cachedBackupMediaDirectory = it.result.mediaDir
}
}
} }
/** /**
@@ -380,15 +538,20 @@ object BackupRepository {
val profileKey: ProfileKey val profileKey: ProfileKey
) )
private fun DatabaseAttachment.toArchiveMediaRequest(backupKey: BackupKey): ArchiveMediaRequest { fun DatabaseAttachment.getMediaName(): MediaName {
val mediaSecrets = backupKey.deriveMediaSecrets(Base64.decode(dataHash!!)) return MediaName.fromDigest(remoteDigest!!)
}
private fun DatabaseAttachment.toArchiveMediaRequest(mediaName: MediaName, backupKey: BackupKey): ArchiveMediaRequest {
val mediaSecrets = backupKey.deriveMediaSecrets(mediaName)
return ArchiveMediaRequest( return ArchiveMediaRequest(
sourceAttachment = ArchiveMediaRequest.SourceAttachment( sourceAttachment = ArchiveMediaRequest.SourceAttachment(
cdn = cdnNumber, cdn = cdn.cdnNumber,
key = remoteLocation!! key = remoteLocation!!
), ),
objectLength = AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(size)).toInt(), objectLength = AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(size)).toInt(),
mediaId = mediaSecrets.id.toString(), mediaId = mediaSecrets.id.encode(),
hmacKey = Base64.encodeWithPadding(mediaSecrets.macKey), hmacKey = Base64.encodeWithPadding(mediaSecrets.macKey),
encryptionKey = Base64.encodeWithPadding(mediaSecrets.cipherKey), encryptionKey = Base64.encodeWithPadding(mediaSecrets.cipherKey),
iv = Base64.encodeWithPadding(mediaSecrets.iv) iv = Base64.encodeWithPadding(mediaSecrets.iv)
@@ -396,12 +559,14 @@ object BackupRepository {
} }
} }
class ExportState(val backupTime: Long) { data class BackupDirectories(val backupDir: String, val mediaDir: String)
class ExportState(val backupTime: Long, val allowMediaBackup: Boolean) {
val recipientIds = HashSet<Long>() val recipientIds = HashSet<Long>()
val threadIds = HashSet<Long>() val threadIds = HashSet<Long>()
} }
class BackupState { class BackupState(val backupKey: BackupKey) {
val backupToLocalRecipientId = HashMap<Long, RecipientId>() val backupToLocalRecipientId = HashMap<Long, RecipientId>()
val chatIdToLocalThreadId = HashMap<Long, Long>() val chatIdToLocalThreadId = HashMap<Long, Long>()
val chatIdToLocalRecipientId = HashMap<Long, RecipientId>() val chatIdToLocalRecipientId = HashMap<Long, RecipientId>()
@@ -0,0 +1,46 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.backup.v2
import org.signal.core.util.concurrent.SignalExecutors
import org.thoughtcrime.securesms.attachments.AttachmentId
import org.thoughtcrime.securesms.attachments.DatabaseAttachment
import org.thoughtcrime.securesms.database.AttachmentTable
import org.thoughtcrime.securesms.database.model.MessageRecord
import org.thoughtcrime.securesms.database.model.MmsMessageRecord
import org.thoughtcrime.securesms.jobs.RestoreAttachmentJob
/**
* Responsible for managing logic around restore prioritization
*/
object BackupRestoreManager {
private val reprioritizedAttachments: HashSet<AttachmentId> = HashSet()
/**
* Raise priority of all attachments for the included message records.
*
* This is so we can make certain attachments get downloaded more quickly
*/
fun prioritizeAttachmentsIfNeeded(messageRecords: List<MessageRecord>) {
SignalExecutors.BOUNDED.execute {
synchronized(this) {
val restoringAttachments: List<AttachmentId> = messageRecords
.mapNotNull { (it as? MmsMessageRecord?)?.slideDeck?.slides }
.flatten()
.mapNotNull { it.asAttachment() as? DatabaseAttachment }
.filter { it.transferState == AttachmentTable.TRANSFER_RESTORE_IN_PROGRESS && !reprioritizedAttachments.contains(it.attachmentId) }
.map { it.attachmentId }
reprioritizedAttachments += restoringAttachments
if (restoringAttachments.isNotEmpty()) {
RestoreAttachmentJob.modifyPriorities(restoringAttachments.toSet(), 1)
}
}
}
}
}
@@ -0,0 +1,39 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.backup.v2
import org.thoughtcrime.securesms.attachments.AttachmentId
import org.whispersystems.signalservice.api.archive.BatchArchiveMediaResponse
/**
* Result of attempting to batch copy multiple attachments at once with helpers for
* processing the collection of mini-responses.
*/
data class BatchArchiveMediaResult(
private val response: BatchArchiveMediaResponse,
private val mediaIdToAttachmentId: Map<String, AttachmentId>,
private val attachmentIdToMediaName: Map<AttachmentId, String>
) {
val successfulResponses: Sequence<BatchArchiveMediaResponse.BatchArchiveMediaItemResponse>
get() = response
.responses
.asSequence()
.filter { it.status == 200 }
val sourceNotFoundResponses: Sequence<BatchArchiveMediaResponse.BatchArchiveMediaItemResponse>
get() = response
.responses
.asSequence()
.filter { it.status == 410 }
fun mediaIdToAttachmentId(mediaId: String): AttachmentId {
return mediaIdToAttachmentId[mediaId]!!
}
fun attachmentIdToMediaName(attachmentId: AttachmentId): String {
return attachmentIdToMediaName[attachmentId]!!
}
}
@@ -17,7 +17,9 @@ import org.signal.core.util.requireBoolean
import org.signal.core.util.requireInt import org.signal.core.util.requireInt
import org.signal.core.util.requireLong import org.signal.core.util.requireLong
import org.signal.core.util.requireString import org.signal.core.util.requireString
import org.thoughtcrime.securesms.attachments.Cdn
import org.thoughtcrime.securesms.attachments.DatabaseAttachment import org.thoughtcrime.securesms.attachments.DatabaseAttachment
import org.thoughtcrime.securesms.backup.v2.BackupRepository.getMediaName
import org.thoughtcrime.securesms.backup.v2.proto.CallChatUpdate import org.thoughtcrime.securesms.backup.v2.proto.CallChatUpdate
import org.thoughtcrime.securesms.backup.v2.proto.ChatItem import org.thoughtcrime.securesms.backup.v2.proto.ChatItem
import org.thoughtcrime.securesms.backup.v2.proto.ChatUpdateMessage import org.thoughtcrime.securesms.backup.v2.proto.ChatUpdateMessage
@@ -36,6 +38,7 @@ import org.thoughtcrime.securesms.backup.v2.proto.SimpleChatUpdate
import org.thoughtcrime.securesms.backup.v2.proto.StandardMessage import org.thoughtcrime.securesms.backup.v2.proto.StandardMessage
import org.thoughtcrime.securesms.backup.v2.proto.Text import org.thoughtcrime.securesms.backup.v2.proto.Text
import org.thoughtcrime.securesms.backup.v2.proto.ThreadMergeChatUpdate import org.thoughtcrime.securesms.backup.v2.proto.ThreadMergeChatUpdate
import org.thoughtcrime.securesms.database.AttachmentTable
import org.thoughtcrime.securesms.database.GroupReceiptTable import org.thoughtcrime.securesms.database.GroupReceiptTable
import org.thoughtcrime.securesms.database.MessageTable import org.thoughtcrime.securesms.database.MessageTable
import org.thoughtcrime.securesms.database.MessageTypes import org.thoughtcrime.securesms.database.MessageTypes
@@ -73,7 +76,7 @@ import org.thoughtcrime.securesms.backup.v2.proto.BodyRange as BackupBodyRange
* *
* All of this complexity is hidden from the user -- they just get a normal iterator interface. * All of this complexity is hidden from the user -- they just get a normal iterator interface.
*/ */
class ChatItemExportIterator(private val cursor: Cursor, private val batchSize: Int) : Iterator<ChatItem>, Closeable { class ChatItemExportIterator(private val cursor: Cursor, private val batchSize: Int, private val archiveMedia: Boolean) : Iterator<ChatItem>, Closeable {
companion object { companion object {
private val TAG = Log.tag(ChatItemExportIterator::class.java) private val TAG = Log.tag(ChatItemExportIterator::class.java)
@@ -139,6 +142,7 @@ class ChatItemExportIterator(private val cursor: Cursor, private val batchSize:
builder.expiresInMs = null builder.expiresInMs = null
} }
MessageTypes.isProfileChange(record.type) -> { MessageTypes.isProfileChange(record.type) -> {
if (record.body == null) continue
builder.updateMessage = ChatUpdateMessage( builder.updateMessage = ChatUpdateMessage(
profileChange = try { profileChange = try {
val decoded: ByteArray = Base64.decode(record.body!!) val decoded: ByteArray = Base64.decode(record.body!!)
@@ -354,24 +358,46 @@ class ChatItemExportIterator(private val cursor: Cursor, private val batchSize:
} }
private fun DatabaseAttachment.toBackupAttachment(): MessageAttachment { private fun DatabaseAttachment.toBackupAttachment(): MessageAttachment {
val builder = FilePointer.Builder()
builder.contentType = contentType
builder.incrementalMac = incrementalDigest?.toByteString()
builder.incrementalMacChunkSize = incrementalMacChunkSize
builder.fileName = fileName
builder.width = width
builder.height = height
builder.caption = caption
builder.blurHash = blurHash?.hash
if (remoteKey.isNullOrBlank() || remoteDigest == null || size == 0L) {
builder.invalidAttachmentLocator = FilePointer.InvalidAttachmentLocator()
} else {
if (archiveMedia) {
builder.backupLocator = FilePointer.BackupLocator(
mediaName = archiveMediaName ?: this.getMediaName().toString(),
cdnNumber = if (archiveMediaName != null) archiveCdn else Cdn.CDN_3.cdnNumber, // TODO (clark): Update when new proto with optional cdn is landed
key = decode(remoteKey).toByteString(),
size = this.size.toInt(),
digest = remoteDigest.toByteString()
)
} else {
if (remoteLocation.isNullOrBlank()) {
builder.invalidAttachmentLocator = FilePointer.InvalidAttachmentLocator()
} else {
builder.attachmentLocator = FilePointer.AttachmentLocator(
cdnKey = this.remoteLocation,
cdnNumber = this.cdn.cdnNumber,
uploadTimestamp = this.uploadTimestamp,
key = decode(remoteKey).toByteString(),
size = this.size.toInt(),
digest = remoteDigest.toByteString()
)
}
}
}
return MessageAttachment( return MessageAttachment(
pointer = FilePointer( pointer = builder.build(),
attachmentLocator = FilePointer.AttachmentLocator( wasDownloaded = this.transferState == AttachmentTable.TRANSFER_PROGRESS_DONE || this.transferState == AttachmentTable.TRANSFER_NEEDS_RESTORE,
cdnKey = this.remoteLocation ?: "", flag = if (voiceNote) MessageAttachment.Flag.VOICE_MESSAGE else if (videoGif) MessageAttachment.Flag.GIF else if (borderless) MessageAttachment.Flag.BORDERLESS else MessageAttachment.Flag.NONE
cdnNumber = this.cdnNumber,
uploadTimestamp = this.uploadTimestamp
),
key = if (remoteKey != null) decode(remoteKey).toByteString() else null,
contentType = this.contentType,
size = this.size.toInt(),
incrementalMac = this.incrementalDigest?.toByteString(),
incrementalMacChunkSize = this.incrementalMacChunkSize,
fileName = this.fileName,
width = this.width,
height = this.height,
caption = this.caption,
blurHash = this.blurHash?.hash
)
) )
} }
@@ -13,8 +13,11 @@ import org.signal.core.util.logging.Log
import org.signal.core.util.orNull import org.signal.core.util.orNull
import org.signal.core.util.requireLong import org.signal.core.util.requireLong
import org.signal.core.util.toInt import org.signal.core.util.toInt
import org.thoughtcrime.securesms.attachments.ArchivedAttachment
import org.thoughtcrime.securesms.attachments.Attachment import org.thoughtcrime.securesms.attachments.Attachment
import org.thoughtcrime.securesms.attachments.Cdn
import org.thoughtcrime.securesms.attachments.PointerAttachment import org.thoughtcrime.securesms.attachments.PointerAttachment
import org.thoughtcrime.securesms.attachments.TombstoneAttachment
import org.thoughtcrime.securesms.backup.v2.BackupState import org.thoughtcrime.securesms.backup.v2.BackupState
import org.thoughtcrime.securesms.backup.v2.proto.BodyRange import org.thoughtcrime.securesms.backup.v2.proto.BodyRange
import org.thoughtcrime.securesms.backup.v2.proto.ChatItem import org.thoughtcrime.securesms.backup.v2.proto.ChatItem
@@ -26,6 +29,7 @@ import org.thoughtcrime.securesms.backup.v2.proto.Reaction
import org.thoughtcrime.securesms.backup.v2.proto.SendStatus import org.thoughtcrime.securesms.backup.v2.proto.SendStatus
import org.thoughtcrime.securesms.backup.v2.proto.SimpleChatUpdate import org.thoughtcrime.securesms.backup.v2.proto.SimpleChatUpdate
import org.thoughtcrime.securesms.backup.v2.proto.StandardMessage import org.thoughtcrime.securesms.backup.v2.proto.StandardMessage
import org.thoughtcrime.securesms.database.AttachmentTable
import org.thoughtcrime.securesms.database.CallTable import org.thoughtcrime.securesms.database.CallTable
import org.thoughtcrime.securesms.database.GroupReceiptTable import org.thoughtcrime.securesms.database.GroupReceiptTable
import org.thoughtcrime.securesms.database.MessageTable import org.thoughtcrime.securesms.database.MessageTable
@@ -48,11 +52,12 @@ import org.thoughtcrime.securesms.mms.QuoteModel
import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.recipients.RecipientId import org.thoughtcrime.securesms.recipients.RecipientId
import org.thoughtcrime.securesms.util.JsonUtils import org.thoughtcrime.securesms.util.JsonUtils
import org.whispersystems.signalservice.api.backup.MediaName
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId
import org.whispersystems.signalservice.api.messages.SignalServiceDataMessage
import org.whispersystems.signalservice.api.push.ServiceId import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.api.util.UuidUtil import org.whispersystems.signalservice.api.util.UuidUtil
import org.whispersystems.signalservice.internal.push.DataMessage
import java.util.Optional import java.util.Optional
/** /**
@@ -570,12 +575,12 @@ class ChatItemImportInserter(
pointer.attachmentLocator.cdnNumber, pointer.attachmentLocator.cdnNumber,
SignalServiceAttachmentRemoteId.from(pointer.attachmentLocator.cdnKey), SignalServiceAttachmentRemoteId.from(pointer.attachmentLocator.cdnKey),
contentType, contentType,
pointer.key?.toByteArray(), pointer.attachmentLocator.key.toByteArray(),
Optional.ofNullable(pointer.size), Optional.ofNullable(pointer.attachmentLocator.size),
Optional.empty(), Optional.empty(),
pointer.width ?: 0, pointer.width ?: 0,
pointer.height ?: 0, pointer.height ?: 0,
Optional.empty(), Optional.ofNullable(pointer.attachmentLocator.digest.toByteArray()),
Optional.ofNullable(pointer.incrementalMac?.toByteArray()), Optional.ofNullable(pointer.incrementalMac?.toByteArray()),
pointer.incrementalMacChunkSize ?: 0, pointer.incrementalMacChunkSize ?: 0,
Optional.ofNullable(fileName), Optional.ofNullable(fileName),
@@ -586,14 +591,51 @@ class ChatItemImportInserter(
Optional.ofNullable(pointer.blurHash), Optional.ofNullable(pointer.blurHash),
pointer.attachmentLocator.uploadTimestamp pointer.attachmentLocator.uploadTimestamp
) )
return PointerAttachment.forPointer(Optional.of(signalAttachmentPointer)).orNull() return PointerAttachment.forPointer(
pointer = Optional.of(signalAttachmentPointer),
transferState = if (wasDownloaded) AttachmentTable.TRANSFER_NEEDS_RESTORE else AttachmentTable.TRANSFER_PROGRESS_PENDING
).orNull()
} else if (pointer.invalidAttachmentLocator != null) {
return TombstoneAttachment(
contentType = contentType,
incrementalMac = pointer.incrementalMac?.toByteArray(),
incrementalMacChunkSize = pointer.incrementalMacChunkSize,
width = pointer.width,
height = pointer.height,
caption = pointer.caption,
blurHash = pointer.blurHash,
voiceNote = flag == MessageAttachment.Flag.VOICE_MESSAGE,
borderless = flag == MessageAttachment.Flag.BORDERLESS,
gif = flag == MessageAttachment.Flag.GIF,
quote = false
)
} else if (pointer.backupLocator != null) {
return ArchivedAttachment(
contentType = contentType,
size = pointer.backupLocator.size.toLong(),
cdn = Cdn.fromCdnNumber(pointer.backupLocator.cdnNumber),
cdnKey = pointer.backupLocator.key.toByteArray(),
archiveMediaName = pointer.backupLocator.mediaName,
archiveMediaId = backupState.backupKey.deriveMediaId(MediaName(pointer.backupLocator.mediaName)).encode(),
digest = pointer.backupLocator.digest.toByteArray(),
incrementalMac = pointer.incrementalMac?.toByteArray(),
incrementalMacChunkSize = pointer.incrementalMacChunkSize,
width = pointer.width,
height = pointer.height,
caption = pointer.caption,
blurHash = pointer.blurHash,
voiceNote = flag == MessageAttachment.Flag.VOICE_MESSAGE,
borderless = flag == MessageAttachment.Flag.BORDERLESS,
gif = flag == MessageAttachment.Flag.GIF,
quote = false
)
} }
return null return null
} }
private fun Quote.QuotedAttachment.toLocalAttachment(): Attachment? { private fun Quote.QuotedAttachment.toLocalAttachment(): Attachment? {
return thumbnail?.toLocalAttachment(this.contentType, this.fileName) return thumbnail?.toLocalAttachment(this.contentType, this.fileName)
?: if (this.contentType == null) null else PointerAttachment.forPointer(SignalServiceDataMessage.Quote.QuotedAttachment(contentType = this.contentType!!, fileName = this.fileName, thumbnail = null)).orNull() ?: if (this.contentType == null) null else PointerAttachment.forPointer(quotedAttachment = DataMessage.Quote.QuotedAttachment(contentType = this.contentType, fileName = this.fileName, thumbnail = null)).orNull()
} }
private class MessageInsert(val contentValues: ContentValues, val followUp: ((Long) -> Unit)?) private class MessageInsert(val contentValues: ContentValues, val followUp: ((Long) -> Unit)?)
@@ -16,7 +16,7 @@ import java.util.concurrent.TimeUnit
private val TAG = Log.tag(MessageTable::class.java) private val TAG = Log.tag(MessageTable::class.java)
private const val BASE_TYPE = "base_type" private const val BASE_TYPE = "base_type"
fun MessageTable.getMessagesForBackup(backupTime: Long): ChatItemExportIterator { fun MessageTable.getMessagesForBackup(backupTime: Long, archiveMedia: Boolean): ChatItemExportIterator {
val cursor = readableDatabase val cursor = readableDatabase
.select( .select(
MessageTable.ID, MessageTable.ID,
@@ -64,7 +64,7 @@ fun MessageTable.getMessagesForBackup(backupTime: Long): ChatItemExportIterator
.orderBy("${MessageTable.DATE_RECEIVED} ASC") .orderBy("${MessageTable.DATE_RECEIVED} ASC")
.run() .run()
return ChatItemExportIterator(cursor, 100) return ChatItemExportIterator(cursor, 100, archiveMedia)
} }
fun MessageTable.createChatItemInserter(backupState: BackupState): ChatItemImportInserter { fun MessageTable.createChatItemInserter(backupState: BackupState): ChatItemImportInserter {
@@ -19,7 +19,7 @@ object ChatItemBackupProcessor {
val TAG = Log.tag(ChatItemBackupProcessor::class.java) val TAG = Log.tag(ChatItemBackupProcessor::class.java)
fun export(exportState: ExportState, emitter: BackupFrameEmitter) { fun export(exportState: ExportState, emitter: BackupFrameEmitter) {
SignalDatabase.messages.getMessagesForBackup(exportState.backupTime).use { chatItems -> SignalDatabase.messages.getMessagesForBackup(exportState.backupTime, exportState.allowMediaBackup).use { chatItems ->
for (chatItem in chatItems) { for (chatItem in chatItems) {
if (exportState.threadIds.contains(chatItem.chatId)) { if (exportState.threadIds.contains(chatItem.chatId)) {
emitter.emit(Frame(chatItem = chatItem)) emitter.emit(Frame(chatItem = chatItem))
@@ -26,6 +26,7 @@ import androidx.compose.material3.Surface
import androidx.compose.material3.Switch import androidx.compose.material3.Switch
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.SideEffect
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
@@ -80,6 +81,15 @@ class MessageBackupsTestRestoreActivity : BaseActivity() {
.fillMaxSize() .fillMaxSize()
.padding(16.dp) .padding(16.dp)
) { ) {
Buttons.LargePrimary(
onClick = this@MessageBackupsTestRestoreActivity::restoreFromServer,
enabled = !state.importState.inProgress
) {
Text("Restore")
}
Spacer(modifier = Modifier.height(8.dp))
Row( Row(
verticalAlignment = Alignment.CenterVertically verticalAlignment = Alignment.CenterVertically
) { ) {
@@ -120,9 +130,20 @@ class MessageBackupsTestRestoreActivity : BaseActivity() {
} }
} }
} }
if (state.importState == MessageBackupsTestRestoreViewModel.ImportState.RESTORED) {
SideEffect {
RegistrationUtil.maybeMarkRegistrationComplete()
ApplicationDependencies.getJobManager().add(ProfileUploadJob())
startActivity(MainActivity.clearTop(this))
}
}
} }
} }
private fun restoreFromServer() {
viewModel.restore()
}
private fun continueRegistration() { private fun continueRegistration() {
if (Recipient.self().profileName.isEmpty || !AvatarHelper.hasAvatar(this, Recipient.self().id)) { if (Recipient.self().profileName.isEmpty || !AvatarHelper.hasAvatar(this, Recipient.self().id)) {
val main = MainActivity.clearTop(this) val main = MainActivity.clearTop(this)
@@ -15,8 +15,12 @@ import io.reactivex.rxjava3.disposables.CompositeDisposable
import io.reactivex.rxjava3.kotlin.plusAssign import io.reactivex.rxjava3.kotlin.plusAssign
import io.reactivex.rxjava3.kotlin.subscribeBy import io.reactivex.rxjava3.kotlin.subscribeBy
import io.reactivex.rxjava3.schedulers.Schedulers import io.reactivex.rxjava3.schedulers.Schedulers
import org.signal.core.util.orNull
import org.signal.libsignal.zkgroup.profiles.ProfileKey import org.signal.libsignal.zkgroup.profiles.ProfileKey
import org.thoughtcrime.securesms.backup.v2.BackupRepository import org.thoughtcrime.securesms.backup.v2.BackupRepository
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
import org.thoughtcrime.securesms.jobmanager.JobTracker
import org.thoughtcrime.securesms.jobs.BackupRestoreJob
import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.Recipient
import java.io.InputStream import java.io.InputStream
@@ -40,6 +44,19 @@ class MessageBackupsTestRestoreViewModel : ViewModel() {
} }
} }
fun restore() {
_state.value = _state.value.copy(importState = ImportState.IN_PROGRESS)
disposables += Single.fromCallable {
val jobState = ApplicationDependencies.getJobManager().runSynchronously(BackupRestoreJob(), 120_000)
jobState.orNull() == JobTracker.JobState.SUCCESS
}
.subscribeOn(Schedulers.io())
.observeOn(AndroidSchedulers.mainThread())
.subscribeBy {
_state.value = _state.value.copy(importState = ImportState.RESTORED)
}
}
fun onPlaintextToggled() { fun onPlaintextToggled() {
_state.value = _state.value.copy(plaintext = !_state.value.plaintext) _state.value = _state.value.copy(plaintext = !_state.value.plaintext)
} }
@@ -54,6 +71,6 @@ class MessageBackupsTestRestoreViewModel : ViewModel() {
) )
enum class ImportState(val inProgress: Boolean = false) { enum class ImportState(val inProgress: Boolean = false) {
NONE, IN_PROGRESS(true) NONE, IN_PROGRESS(true), RESTORED
} }
} }
@@ -29,6 +29,11 @@ import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material3.Button import androidx.compose.material3.Button
import androidx.compose.material3.Checkbox import androidx.compose.material3.Checkbox
import androidx.compose.material3.CircularProgressIndicator import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.DropdownMenu
import androidx.compose.material3.DropdownMenuItem
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold import androidx.compose.material3.Scaffold
import androidx.compose.material3.SnackbarHostState import androidx.compose.material3.SnackbarHostState
@@ -37,6 +42,8 @@ import androidx.compose.material3.Switch
import androidx.compose.material3.Tab import androidx.compose.material3.Tab
import androidx.compose.material3.TabRow import androidx.compose.material3.TabRow
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.material3.TopAppBar
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
@@ -46,10 +53,12 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.fragment.app.viewModels import androidx.fragment.app.viewModels
import androidx.navigation.fragment.findNavController
import org.signal.core.ui.Buttons import org.signal.core.ui.Buttons
import org.signal.core.ui.Dividers import org.signal.core.ui.Dividers
import org.signal.core.ui.Snackbars import org.signal.core.ui.Snackbars
@@ -57,10 +66,13 @@ import org.signal.core.ui.theme.SignalTheme
import org.signal.core.util.bytes import org.signal.core.util.bytes
import org.signal.core.util.getLength import org.signal.core.util.getLength
import org.signal.core.util.roundedString import org.signal.core.util.roundedString
import org.thoughtcrime.securesms.R
import org.thoughtcrime.securesms.attachments.AttachmentId
import org.thoughtcrime.securesms.components.settings.app.internal.backup.InternalBackupPlaygroundViewModel.BackupState import org.thoughtcrime.securesms.components.settings.app.internal.backup.InternalBackupPlaygroundViewModel.BackupState
import org.thoughtcrime.securesms.components.settings.app.internal.backup.InternalBackupPlaygroundViewModel.BackupUploadState import org.thoughtcrime.securesms.components.settings.app.internal.backup.InternalBackupPlaygroundViewModel.BackupUploadState
import org.thoughtcrime.securesms.components.settings.app.internal.backup.InternalBackupPlaygroundViewModel.ScreenState import org.thoughtcrime.securesms.components.settings.app.internal.backup.InternalBackupPlaygroundViewModel.ScreenState
import org.thoughtcrime.securesms.compose.ComposeFragment import org.thoughtcrime.securesms.compose.ComposeFragment
import org.thoughtcrime.securesms.keyvalue.SignalStore
class InternalBackupPlaygroundFragment : ComposeFragment() { class InternalBackupPlaygroundFragment : ComposeFragment() {
@@ -114,6 +126,8 @@ class InternalBackupPlaygroundFragment : ComposeFragment() {
} }
Tabs( Tabs(
onBack = { findNavController().popBackStack() },
onDeleteAllArchivedMedia = { viewModel.deleteAllArchivedMedia() },
mainContent = { mainContent = {
Screen( Screen(
state = state, state = state,
@@ -149,25 +163,32 @@ class InternalBackupPlaygroundFragment : ComposeFragment() {
} }
validateFileLauncher.launch(intent) validateFileLauncher.launch(intent)
} },
onTriggerBackupJobClicked = { viewModel.triggerBackupJob() },
onRestoreFromRemoteClicked = { viewModel.restoreFromRemote() }
) )
}, },
mediaContent = { snackbarHostState -> mediaContent = { snackbarHostState ->
MediaList( MediaList(
enabled = SignalStore.backup().canReadWriteToArchiveCdn,
state = mediaState, state = mediaState,
snackbarHostState = snackbarHostState, snackbarHostState = snackbarHostState,
backupAttachmentMedia = { viewModel.backupAttachmentMedia(it) }, archiveAttachmentMedia = { viewModel.archiveAttachmentMedia(it) },
deleteBackupAttachmentMedia = { viewModel.deleteBackupAttachmentMedia(it) }, deleteArchivedMedia = { viewModel.deleteArchivedMedia(it) },
batchBackupAttachmentMedia = { viewModel.backupAttachmentMedia(it) }, batchArchiveAttachmentMedia = { viewModel.archiveAttachmentMedia(it) },
batchDeleteBackupAttachmentMedia = { viewModel.deleteBackupAttachmentMedia(it) } batchDeleteBackupAttachmentMedia = { viewModel.deleteArchivedMedia(it) },
restoreArchivedMedia = { viewModel.restoreArchivedMedia(it) }
) )
} }
) )
} }
} }
@OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
fun Tabs( fun Tabs(
onBack: () -> Unit,
onDeleteAllArchivedMedia: () -> Unit,
mainContent: @Composable () -> Unit, mainContent: @Composable () -> Unit,
mediaContent: @Composable (snackbarHostState: SnackbarHostState) -> Unit mediaContent: @Composable (snackbarHostState: SnackbarHostState) -> Unit
) { ) {
@@ -179,13 +200,36 @@ fun Tabs(
Scaffold( Scaffold(
snackbarHost = { Snackbars.Host(snackbarHostState) }, snackbarHost = { Snackbars.Host(snackbarHostState) },
topBar = { topBar = {
TabRow(selectedTabIndex = tabIndex) { Column {
tabs.forEachIndexed { index, tab -> TopAppBar(
Tab( title = {
text = { Text(tab) }, Text("Backup Playground")
selected = index == tabIndex, },
onClick = { tabIndex = index } navigationIcon = {
) IconButton(onClick = onBack) {
Icon(
painter = painterResource(R.drawable.symbol_arrow_left_24),
tint = MaterialTheme.colorScheme.onSurface,
contentDescription = null
)
}
},
actions = {
if (tabIndex == 1 && SignalStore.backup().canReadWriteToArchiveCdn) {
TextButton(onClick = onDeleteAllArchivedMedia) {
Text(text = "Delete All")
}
}
}
)
TabRow(selectedTabIndex = tabIndex) {
tabs.forEachIndexed { index, tab ->
Tab(
text = { Text(tab) },
selected = index == tabIndex,
onClick = { tabIndex = index }
)
}
} }
} }
} }
@@ -209,7 +253,9 @@ fun Screen(
onSaveToDiskClicked: () -> Unit = {}, onSaveToDiskClicked: () -> Unit = {},
onValidateFileClicked: () -> Unit = {}, onValidateFileClicked: () -> Unit = {},
onUploadToRemoteClicked: () -> Unit = {}, onUploadToRemoteClicked: () -> Unit = {},
onCheckRemoteBackupStateClicked: () -> Unit = {} onCheckRemoteBackupStateClicked: () -> Unit = {},
onTriggerBackupJobClicked: () -> Unit = {},
onRestoreFromRemoteClicked: () -> Unit = {}
) { ) {
Surface { Surface {
Column( Column(
@@ -239,6 +285,13 @@ fun Screen(
Text("Export") Text("Export")
} }
Buttons.LargePrimary(
onClick = onTriggerBackupJobClicked,
enabled = !state.backupState.inProgress
) {
Text("Trigger Backup Job")
}
Dividers.Default() Dividers.Default()
Buttons.LargeTonal( Buttons.LargeTonal(
@@ -280,6 +333,10 @@ fun Screen(
} }
} }
BackupState.BACKUP_JOB_DONE -> {
StateLabel("Backup complete and uploaded")
}
BackupState.IMPORT_IN_PROGRESS -> { BackupState.IMPORT_IN_PROGRESS -> {
StateLabel("Import in progress...") StateLabel("Import in progress...")
} }
@@ -324,6 +381,10 @@ fun Screen(
Spacer(modifier = Modifier.height(8.dp)) Spacer(modifier = Modifier.height(8.dp))
Buttons.LargePrimary(onClick = onRestoreFromRemoteClicked) {
Text("Restore from remote")
}
when (state.uploadState) { when (state.uploadState) {
BackupUploadState.NONE -> { BackupUploadState.NONE -> {
StateLabel("") StateLabel("")
@@ -357,13 +418,24 @@ private fun StateLabel(text: String) {
@OptIn(ExperimentalFoundationApi::class) @OptIn(ExperimentalFoundationApi::class)
@Composable @Composable
fun MediaList( fun MediaList(
enabled: Boolean,
state: InternalBackupPlaygroundViewModel.MediaState, state: InternalBackupPlaygroundViewModel.MediaState,
snackbarHostState: SnackbarHostState, snackbarHostState: SnackbarHostState,
backupAttachmentMedia: (InternalBackupPlaygroundViewModel.BackupAttachment) -> Unit, archiveAttachmentMedia: (InternalBackupPlaygroundViewModel.BackupAttachment) -> Unit,
deleteBackupAttachmentMedia: (InternalBackupPlaygroundViewModel.BackupAttachment) -> Unit, deleteArchivedMedia: (InternalBackupPlaygroundViewModel.BackupAttachment) -> Unit,
batchBackupAttachmentMedia: (Set<String>) -> Unit, batchArchiveAttachmentMedia: (Set<AttachmentId>) -> Unit,
batchDeleteBackupAttachmentMedia: (Set<String>) -> Unit batchDeleteBackupAttachmentMedia: (Set<AttachmentId>) -> Unit,
restoreArchivedMedia: (InternalBackupPlaygroundViewModel.BackupAttachment) -> Unit
) { ) {
if (!enabled) {
Text(
text = "You do not have read/write to archive cdn enabled via SignalStore.backup()",
modifier = Modifier
.padding(16.dp)
)
return
}
LaunchedEffect(state.error?.id) { LaunchedEffect(state.error?.id) {
state.error?.let { state.error?.let {
snackbarHostState.showSnackbar(it.errorText) snackbarHostState.showSnackbar(it.errorText)
@@ -384,51 +456,88 @@ fun MediaList(
.combinedClickable( .combinedClickable(
onClick = { onClick = {
if (selectionState.selecting) { if (selectionState.selecting) {
selectionState = selectionState.copy(selected = if (selectionState.selected.contains(attachment.mediaId)) selectionState.selected - attachment.mediaId else selectionState.selected + attachment.mediaId) selectionState = selectionState.copy(selected = if (selectionState.selected.contains(attachment.id)) selectionState.selected - attachment.id else selectionState.selected + attachment.id)
} }
}, },
onLongClick = { onLongClick = {
selectionState = if (selectionState.selecting) MediaMultiSelectState() else MediaMultiSelectState(selecting = true, selected = setOf(attachment.mediaId)) selectionState = if (selectionState.selecting) MediaMultiSelectState() else MediaMultiSelectState(selecting = true, selected = setOf(attachment.id))
} }
) )
.padding(horizontal = 16.dp, vertical = 8.dp) .padding(horizontal = 16.dp, vertical = 8.dp)
) { ) {
if (selectionState.selecting) { if (selectionState.selecting) {
Checkbox( Checkbox(
checked = selectionState.selected.contains(attachment.mediaId), checked = selectionState.selected.contains(attachment.id),
onCheckedChange = { selected -> onCheckedChange = { selected ->
selectionState = selectionState.copy(selected = if (selected) selectionState.selected + attachment.mediaId else selectionState.selected - attachment.mediaId) selectionState = selectionState.copy(selected = if (selected) selectionState.selected + attachment.id else selectionState.selected - attachment.id)
} }
) )
} }
Column(modifier = Modifier.weight(1f, true)) { Column(modifier = Modifier.weight(1f, true)) {
Text(text = "Attachment ${attachment.title}") Text(text = attachment.title)
Text(text = "State: ${attachment.state}") Text(text = "State: ${attachment.state}")
} }
if (attachment.state == InternalBackupPlaygroundViewModel.BackupAttachment.State.INIT || if (attachment.state == InternalBackupPlaygroundViewModel.BackupAttachment.State.IN_PROGRESS) {
attachment.state == InternalBackupPlaygroundViewModel.BackupAttachment.State.IN_PROGRESS
) {
CircularProgressIndicator() CircularProgressIndicator()
} else { } else {
Button( Button(
enabled = !selectionState.selecting, enabled = !selectionState.selecting,
onClick = { onClick = {
when (attachment.state) { when (attachment.state) {
InternalBackupPlaygroundViewModel.BackupAttachment.State.LOCAL_ONLY -> backupAttachmentMedia(attachment) InternalBackupPlaygroundViewModel.BackupAttachment.State.ATTACHMENT_CDN,
InternalBackupPlaygroundViewModel.BackupAttachment.State.UPLOADED -> deleteBackupAttachmentMedia(attachment) InternalBackupPlaygroundViewModel.BackupAttachment.State.LOCAL_ONLY -> archiveAttachmentMedia(attachment)
InternalBackupPlaygroundViewModel.BackupAttachment.State.UPLOADED_UNDOWNLOADED,
InternalBackupPlaygroundViewModel.BackupAttachment.State.UPLOADED_FINAL -> selectionState = selectionState.copy(expandedOption = attachment.dbAttachment.attachmentId)
else -> throw AssertionError("Unsupported state: ${attachment.state}") else -> throw AssertionError("Unsupported state: ${attachment.state}")
} }
} }
) { ) {
Text( Text(
text = when (attachment.state) { text = when (attachment.state) {
InternalBackupPlaygroundViewModel.BackupAttachment.State.ATTACHMENT_CDN,
InternalBackupPlaygroundViewModel.BackupAttachment.State.LOCAL_ONLY -> "Backup" InternalBackupPlaygroundViewModel.BackupAttachment.State.LOCAL_ONLY -> "Backup"
InternalBackupPlaygroundViewModel.BackupAttachment.State.UPLOADED -> "Remote Delete"
InternalBackupPlaygroundViewModel.BackupAttachment.State.UPLOADED_UNDOWNLOADED,
InternalBackupPlaygroundViewModel.BackupAttachment.State.UPLOADED_FINAL -> "Options..."
else -> throw AssertionError("Unsupported state: ${attachment.state}") else -> throw AssertionError("Unsupported state: ${attachment.state}")
} }
) )
DropdownMenu(
expanded = attachment.dbAttachment.attachmentId == selectionState.expandedOption,
onDismissRequest = { selectionState = selectionState.copy(expandedOption = null) }
) {
DropdownMenuItem(
text = { Text("Remote Delete") },
onClick = {
selectionState = selectionState.copy(expandedOption = null)
deleteArchivedMedia(attachment)
}
)
DropdownMenuItem(
text = { Text("Pseudo Restore") },
onClick = {
selectionState = selectionState.copy(expandedOption = null)
restoreArchivedMedia(attachment)
}
)
if (attachment.dbAttachment.dataHash != null && attachment.state == InternalBackupPlaygroundViewModel.BackupAttachment.State.UPLOADED_UNDOWNLOADED) {
DropdownMenuItem(
text = { Text("Re-copy with hash") },
onClick = {
selectionState = selectionState.copy(expandedOption = null)
archiveAttachmentMedia(attachment)
}
)
}
}
} }
} }
} }
@@ -451,7 +560,7 @@ fun MediaList(
Text("Cancel") Text("Cancel")
} }
Button(onClick = { Button(onClick = {
batchBackupAttachmentMedia(selectionState.selected) batchArchiveAttachmentMedia(selectionState.selected)
selectionState = MediaMultiSelectState() selectionState = MediaMultiSelectState()
}) { }) {
Text("Backup") Text("Backup")
@@ -469,7 +578,8 @@ fun MediaList(
private data class MediaMultiSelectState( private data class MediaMultiSelectState(
val selecting: Boolean = false, val selecting: Boolean = false,
val selected: Set<String> = emptySet() val selected: Set<AttachmentId> = emptySet(),
val expandedOption: AttachmentId? = null
) )
@Preview(name = "Light Theme", group = "screen", uiMode = Configuration.UI_MODE_NIGHT_NO) @Preview(name = "Light Theme", group = "screen", uiMode = Configuration.UI_MODE_NIGHT_NO)
@@ -10,30 +10,38 @@ import androidx.compose.runtime.State
import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.mutableStateOf
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import io.reactivex.rxjava3.android.schedulers.AndroidSchedulers import io.reactivex.rxjava3.android.schedulers.AndroidSchedulers
import io.reactivex.rxjava3.core.Completable
import io.reactivex.rxjava3.core.Single import io.reactivex.rxjava3.core.Single
import io.reactivex.rxjava3.disposables.CompositeDisposable import io.reactivex.rxjava3.disposables.CompositeDisposable
import io.reactivex.rxjava3.kotlin.plusAssign import io.reactivex.rxjava3.kotlin.plusAssign
import io.reactivex.rxjava3.kotlin.subscribeBy import io.reactivex.rxjava3.kotlin.subscribeBy
import io.reactivex.rxjava3.schedulers.Schedulers import io.reactivex.rxjava3.schedulers.Schedulers
import org.signal.core.util.Base64
import org.signal.libsignal.zkgroup.profiles.ProfileKey import org.signal.libsignal.zkgroup.profiles.ProfileKey
import org.thoughtcrime.securesms.attachments.AttachmentId
import org.thoughtcrime.securesms.attachments.DatabaseAttachment import org.thoughtcrime.securesms.attachments.DatabaseAttachment
import org.thoughtcrime.securesms.backup.v2.BackupMetadata import org.thoughtcrime.securesms.backup.v2.BackupMetadata
import org.thoughtcrime.securesms.backup.v2.BackupRepository import org.thoughtcrime.securesms.backup.v2.BackupRepository
import org.thoughtcrime.securesms.database.MessageType
import org.thoughtcrime.securesms.database.SignalDatabase import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
import org.thoughtcrime.securesms.jobs.ArchiveAttachmentJob
import org.thoughtcrime.securesms.jobs.AttachmentDownloadJob
import org.thoughtcrime.securesms.jobs.AttachmentUploadJob
import org.thoughtcrime.securesms.jobs.BackupMessagesJob
import org.thoughtcrime.securesms.jobs.BackupRestoreJob
import org.thoughtcrime.securesms.jobs.BackupRestoreMediaJob
import org.thoughtcrime.securesms.keyvalue.SignalStore import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.mms.IncomingMessage
import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.Recipient
import org.whispersystems.signalservice.api.NetworkResult import org.whispersystems.signalservice.api.NetworkResult
import org.whispersystems.signalservice.api.backup.BackupKey import org.whispersystems.signalservice.api.backup.MediaName
import java.io.ByteArrayInputStream import java.io.ByteArrayInputStream
import java.io.InputStream import java.io.InputStream
import java.util.UUID import java.util.UUID
import kotlin.random.Random import kotlin.time.Duration.Companion.seconds
class InternalBackupPlaygroundViewModel : ViewModel() { class InternalBackupPlaygroundViewModel : ViewModel() {
private val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey()
var backupData: ByteArray? = null var backupData: ByteArray? = null
val disposables = CompositeDisposable() val disposables = CompositeDisposable()
@@ -57,6 +65,17 @@ class InternalBackupPlaygroundViewModel : ViewModel() {
} }
} }
fun triggerBackupJob() {
_state.value = _state.value.copy(backupState = BackupState.EXPORT_IN_PROGRESS)
disposables += Single.fromCallable { ApplicationDependencies.getJobManager().runSynchronously(BackupMessagesJob(), 120_000) }
.subscribeOn(Schedulers.io())
.observeOn(AndroidSchedulers.mainThread())
.subscribeBy {
_state.value = _state.value.copy(backupState = BackupState.BACKUP_JOB_DONE)
}
}
fun import() { fun import() {
backupData?.let { backupData?.let {
_state.value = _state.value.copy(backupState = BackupState.IMPORT_IN_PROGRESS) _state.value = _state.value.copy(backupState = BackupState.IMPORT_IN_PROGRESS)
@@ -68,7 +87,7 @@ class InternalBackupPlaygroundViewModel : ViewModel() {
disposables += Single.fromCallable { BackupRepository.import(it.size.toLong(), { ByteArrayInputStream(it) }, selfData, plaintext = plaintext) } disposables += Single.fromCallable { BackupRepository.import(it.size.toLong(), { ByteArrayInputStream(it) }, selfData, plaintext = plaintext) }
.subscribeOn(Schedulers.io()) .subscribeOn(Schedulers.io())
.observeOn(AndroidSchedulers.mainThread()) .observeOn(AndroidSchedulers.mainThread())
.subscribe { nothing -> .subscribeBy {
backupData = null backupData = null
_state.value = _state.value.copy(backupState = BackupState.NONE) _state.value = _state.value.copy(backupState = BackupState.NONE)
} }
@@ -85,7 +104,7 @@ class InternalBackupPlaygroundViewModel : ViewModel() {
disposables += Single.fromCallable { BackupRepository.import(length, inputStreamFactory, selfData, plaintext = plaintext) } disposables += Single.fromCallable { BackupRepository.import(length, inputStreamFactory, selfData, plaintext = plaintext) }
.subscribeOn(Schedulers.io()) .subscribeOn(Schedulers.io())
.observeOn(AndroidSchedulers.mainThread()) .observeOn(AndroidSchedulers.mainThread())
.subscribe { nothing -> .subscribeBy {
backupData = null backupData = null
_state.value = _state.value.copy(backupState = BackupState.NONE) _state.value = _state.value.copy(backupState = BackupState.NONE)
} }
@@ -98,7 +117,7 @@ class InternalBackupPlaygroundViewModel : ViewModel() {
disposables += Single.fromCallable { BackupRepository.validate(length, inputStreamFactory, selfData) } disposables += Single.fromCallable { BackupRepository.validate(length, inputStreamFactory, selfData) }
.subscribeOn(Schedulers.io()) .subscribeOn(Schedulers.io())
.observeOn(AndroidSchedulers.mainThread()) .observeOn(AndroidSchedulers.mainThread())
.subscribe { nothing -> .subscribeBy {
backupData = null backupData = null
_state.value = _state.value.copy(backupState = BackupState.NONE) _state.value = _state.value.copy(backupState = BackupState.NONE)
} }
@@ -142,47 +161,77 @@ class InternalBackupPlaygroundViewModel : ViewModel() {
} }
} }
fun restoreFromRemote() {
_state.value = _state.value.copy(backupState = BackupState.IMPORT_IN_PROGRESS)
disposables += Single.fromCallable {
ApplicationDependencies
.getJobManager()
.startChain(BackupRestoreJob())
.then(BackupRestoreMediaJob())
.enqueueAndBlockUntilCompletion(120.seconds.inWholeMilliseconds)
}
.subscribeOn(Schedulers.io())
.observeOn(AndroidSchedulers.mainThread())
.subscribeBy {
_state.value = _state.value.copy(backupState = BackupState.NONE)
}
}
fun loadMedia() { fun loadMedia() {
disposables += Single disposables += Single
.fromCallable { SignalDatabase.attachments.debugGetLatestAttachments() } .fromCallable { SignalDatabase.attachments.debugGetLatestAttachments() }
.subscribeOn(Schedulers.io()) .subscribeOn(Schedulers.io())
.observeOn(Schedulers.single()) .observeOn(Schedulers.single())
.subscribeBy { .subscribeBy {
_mediaState.set { update(attachments = it.map { a -> BackupAttachment.from(backupKey, a) }) } _mediaState.set { update(attachments = it.map { a -> BackupAttachment(dbAttachment = a) }) }
} }
}
fun archiveAttachmentMedia(attachments: Set<AttachmentId>) {
disposables += Single disposables += Single
.fromCallable { BackupRepository.debugGetArchivedMediaState() } .fromCallable {
val toArchive = mediaState.value
.attachments
.filter { attachments.contains(it.dbAttachment.attachmentId) }
.map { it.dbAttachment }
BackupRepository.archiveMedia(toArchive)
}
.subscribeOn(Schedulers.io()) .subscribeOn(Schedulers.io())
.observeOn(Schedulers.single()) .observeOn(Schedulers.single())
.doOnSubscribe { _mediaState.set { update(inProgress = inProgressMediaIds + attachments) } }
.doOnTerminate { _mediaState.set { update(inProgress = inProgressMediaIds - attachments) } }
.subscribeBy { result -> .subscribeBy { result ->
when (result) { when (result) {
is NetworkResult.Success -> _mediaState.set { update(archiveStateLoaded = true, backedUpMediaIds = result.result.map { it.mediaId }.toSet()) } is NetworkResult.Success -> {
loadMedia()
result
.result
.sourceNotFoundResponses
.forEach {
reUploadAndArchiveMedia(result.result.mediaIdToAttachmentId(it.mediaId))
}
}
else -> _mediaState.set { copy(error = MediaStateError(errorText = "$result")) } else -> _mediaState.set { copy(error = MediaStateError(errorText = "$result")) }
} }
} }
} }
fun backupAttachmentMedia(mediaIds: Set<String>) { fun archiveAttachmentMedia(attachment: BackupAttachment) {
disposables += Single.fromCallable { mediaIds.mapNotNull { mediaState.value.idToAttachment[it]?.dbAttachment }.toList() } disposables += Single.fromCallable { BackupRepository.archiveMedia(attachment.dbAttachment) }
.map { BackupRepository.archiveMedia(it) }
.subscribeOn(Schedulers.io()) .subscribeOn(Schedulers.io())
.observeOn(Schedulers.single()) .observeOn(Schedulers.single())
.doOnSubscribe { _mediaState.set { update(inProgressMediaIds = inProgressMediaIds + mediaIds) } } .doOnSubscribe { _mediaState.set { update(inProgress = inProgressMediaIds + attachment.dbAttachment.attachmentId) } }
.doOnTerminate { _mediaState.set { update(inProgressMediaIds = inProgressMediaIds - mediaIds) } } .doOnTerminate { _mediaState.set { update(inProgress = inProgressMediaIds - attachment.dbAttachment.attachmentId) } }
.subscribeBy { result -> .subscribeBy { result ->
when (result) { when (result) {
is NetworkResult.Success -> { is NetworkResult.Success -> loadMedia()
val response = result.result is NetworkResult.StatusCodeError -> {
val successes = response.responses.filter { it.status == 200 } if (result.code == 410) {
val failures = response.responses - successes.toSet() reUploadAndArchiveMedia(attachment.id)
} else {
_mediaState.set { _mediaState.set { copy(error = MediaStateError(errorText = "$result")) }
var updated = update(backedUpMediaIds = backedUpMediaIds + successes.map { it.mediaId })
if (failures.isNotEmpty()) {
updated = updated.copy(error = MediaStateError(errorText = failures.toString()))
}
updated
} }
} }
@@ -191,49 +240,107 @@ class InternalBackupPlaygroundViewModel : ViewModel() {
} }
} }
fun backupAttachmentMedia(attachment: BackupAttachment) { private fun reUploadAndArchiveMedia(attachmentId: AttachmentId) {
disposables += Single.fromCallable { BackupRepository.archiveMedia(attachment.dbAttachment) } disposables += Single
.fromCallable {
ApplicationDependencies
.getJobManager()
.startChain(AttachmentUploadJob(attachmentId))
.then(ArchiveAttachmentJob(attachmentId))
.enqueueAndBlockUntilCompletion(15.seconds.inWholeMilliseconds)
}
.subscribeOn(Schedulers.io()) .subscribeOn(Schedulers.io())
.observeOn(Schedulers.single()) .observeOn(Schedulers.single())
.doOnSubscribe { _mediaState.set { update(inProgressMediaIds = inProgressMediaIds + attachment.mediaId) } } .doOnSubscribe { _mediaState.set { update(inProgress = inProgressMediaIds + attachmentId) } }
.doOnTerminate { _mediaState.set { update(inProgressMediaIds = inProgressMediaIds - attachment.mediaId) } } .doOnTerminate { _mediaState.set { update(inProgress = inProgressMediaIds - attachmentId) } }
.subscribeBy { .subscribeBy {
when (it) { if (it.isPresent && it.get().isComplete) {
is NetworkResult.Success -> { loadMedia()
_mediaState.set { update(backedUpMediaIds = backedUpMediaIds + attachment.mediaId) } } else {
} _mediaState.set { copy(error = MediaStateError(errorText = "Reupload slow or failed, try again")) }
else -> _mediaState.set { copy(error = MediaStateError(errorText = "$it")) }
} }
} }
} }
fun deleteBackupAttachmentMedia(mediaIds: Set<String>) { fun deleteArchivedMedia(attachmentIds: Set<AttachmentId>) {
deleteBackupAttachmentMedia(mediaIds.mapNotNull { mediaState.value.idToAttachment[it] }.toList()) deleteArchivedMedia(mediaState.value.attachments.filter { attachmentIds.contains(it.dbAttachment.attachmentId) })
} }
fun deleteBackupAttachmentMedia(attachment: BackupAttachment) { fun deleteArchivedMedia(attachment: BackupAttachment) {
deleteBackupAttachmentMedia(listOf(attachment)) deleteArchivedMedia(listOf(attachment))
} }
private fun deleteBackupAttachmentMedia(attachments: List<BackupAttachment>) { private fun deleteArchivedMedia(attachments: List<BackupAttachment>) {
val ids = attachments.map { it.mediaId }.toSet() val ids = attachments.map { it.dbAttachment.attachmentId }.toSet()
disposables += Single.fromCallable { BackupRepository.deleteArchivedMedia(attachments.map { it.dbAttachment }) } disposables += Single.fromCallable { BackupRepository.deleteArchivedMedia(attachments.map { it.dbAttachment }) }
.subscribeOn(Schedulers.io()) .subscribeOn(Schedulers.io())
.observeOn(Schedulers.single()) .observeOn(Schedulers.single())
.doOnSubscribe { _mediaState.set { update(inProgressMediaIds = inProgressMediaIds + ids) } } .doOnSubscribe { _mediaState.set { update(inProgress = inProgressMediaIds + ids) } }
.doOnTerminate { _mediaState.set { update(inProgressMediaIds = inProgressMediaIds - ids) } } .doOnTerminate { _mediaState.set { update(inProgress = inProgressMediaIds - ids) } }
.subscribeBy { .subscribeBy {
when (it) { when (it) {
is NetworkResult.Success -> { is NetworkResult.Success -> loadMedia()
_mediaState.set { update(backedUpMediaIds = backedUpMediaIds - ids) }
}
else -> _mediaState.set { copy(error = MediaStateError(errorText = "$it")) } else -> _mediaState.set { copy(error = MediaStateError(errorText = "$it")) }
} }
} }
} }
fun deleteAllArchivedMedia() {
disposables += Single
.fromCallable { BackupRepository.debugDeleteAllArchivedMedia() }
.subscribeOn(Schedulers.io())
.observeOn(Schedulers.single())
.subscribeBy { result ->
when (result) {
is NetworkResult.Success -> loadMedia()
else -> _mediaState.set { copy(error = MediaStateError(errorText = "$result")) }
}
}
}
fun restoreArchivedMedia(attachment: BackupAttachment) {
disposables += Completable
.fromCallable {
val recipientId = SignalStore.releaseChannelValues().releaseChannelRecipientId!!
val threadId = SignalDatabase.threads.getOrCreateThreadIdFor(Recipient.resolved(recipientId))
val message = IncomingMessage(
type = MessageType.NORMAL,
from = recipientId,
sentTimeMillis = System.currentTimeMillis(),
serverTimeMillis = System.currentTimeMillis(),
receivedTimeMillis = System.currentTimeMillis(),
body = "Restored from Archive!?",
serverGuid = UUID.randomUUID().toString()
)
val insertMessage = SignalDatabase.messages.insertMessageInbox(message, threadId).get()
SignalDatabase.attachments.debugCopyAttachmentForArchiveRestore(
insertMessage.messageId,
attachment.dbAttachment
)
val archivedAttachment = SignalDatabase.attachments.getAttachmentsForMessage(insertMessage.messageId).first()
ApplicationDependencies.getJobManager().add(
AttachmentDownloadJob(
messageId = insertMessage.messageId,
attachmentId = archivedAttachment.attachmentId,
manual = false,
forceArchiveDownload = true
)
)
}
.subscribeOn(Schedulers.io())
.observeOn(Schedulers.single())
.subscribeBy(
onError = {
_mediaState.set { copy(error = MediaStateError(errorText = "$it")) }
}
)
}
override fun onCleared() { override fun onCleared() {
disposables.clear() disposables.clear()
} }
@@ -246,7 +353,7 @@ class InternalBackupPlaygroundViewModel : ViewModel() {
) )
enum class BackupState(val inProgress: Boolean = false) { enum class BackupState(val inProgress: Boolean = false) {
NONE, EXPORT_IN_PROGRESS(true), EXPORT_DONE, IMPORT_IN_PROGRESS(true) NONE, EXPORT_IN_PROGRESS(true), EXPORT_DONE, BACKUP_JOB_DONE, IMPORT_IN_PROGRESS(true)
} }
enum class BackupUploadState(val inProgress: Boolean = false) { enum class BackupUploadState(val inProgress: Boolean = false) {
@@ -261,67 +368,59 @@ class InternalBackupPlaygroundViewModel : ViewModel() {
} }
data class MediaState( data class MediaState(
val backupStateLoaded: Boolean = false,
val attachments: List<BackupAttachment> = emptyList(), val attachments: List<BackupAttachment> = emptyList(),
val backedUpMediaIds: Set<String> = emptySet(), val inProgressMediaIds: Set<AttachmentId> = emptySet(),
val inProgressMediaIds: Set<String> = emptySet(),
val error: MediaStateError? = null val error: MediaStateError? = null
) { ) {
val idToAttachment: Map<String, BackupAttachment> = attachments.associateBy { it.mediaId }
fun update( fun update(
archiveStateLoaded: Boolean = this.backupStateLoaded,
attachments: List<BackupAttachment> = this.attachments, attachments: List<BackupAttachment> = this.attachments,
backedUpMediaIds: Set<String> = this.backedUpMediaIds, inProgress: Set<AttachmentId> = this.inProgressMediaIds
inProgressMediaIds: Set<String> = this.inProgressMediaIds
): MediaState { ): MediaState {
val updatedAttachments = if (archiveStateLoaded) { val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey()
attachments.map {
val state = if (inProgressMediaIds.contains(it.mediaId)) {
BackupAttachment.State.IN_PROGRESS
} else if (backedUpMediaIds.contains(it.mediaId)) {
BackupAttachment.State.UPLOADED
} else {
BackupAttachment.State.LOCAL_ONLY
}
it.copy(state = state) val updatedAttachments = attachments.map {
val state = if (inProgress.contains(it.dbAttachment.attachmentId)) {
BackupAttachment.State.IN_PROGRESS
} else if (it.dbAttachment.archiveMediaName != null) {
if (it.dbAttachment.remoteDigest != null) {
val mediaId = backupKey.deriveMediaId(MediaName(it.dbAttachment.archiveMediaName)).encode()
if (it.dbAttachment.archiveMediaId == mediaId) {
BackupAttachment.State.UPLOADED_FINAL
} else {
BackupAttachment.State.UPLOADED_UNDOWNLOADED
}
} else {
BackupAttachment.State.UPLOADED_UNDOWNLOADED
}
} else if (it.dbAttachment.dataHash == null) {
BackupAttachment.State.ATTACHMENT_CDN
} else {
BackupAttachment.State.LOCAL_ONLY
} }
} else {
attachments it.copy(state = state)
} }
return copy( return copy(
backupStateLoaded = archiveStateLoaded, attachments = updatedAttachments
attachments = updatedAttachments,
backedUpMediaIds = backedUpMediaIds
) )
} }
} }
data class BackupAttachment( data class BackupAttachment(
val dbAttachment: DatabaseAttachment, val dbAttachment: DatabaseAttachment,
val state: State = State.INIT, val state: State = State.LOCAL_ONLY
val mediaId: String = Base64.encodeUrlSafeWithPadding(Random.nextBytes(15))
) { ) {
val id: Any = dbAttachment.attachmentId val id: AttachmentId = dbAttachment.attachmentId
val title: String = dbAttachment.attachmentId.toString() val title: String = dbAttachment.attachmentId.toString()
enum class State { enum class State {
INIT, ATTACHMENT_CDN,
LOCAL_ONLY, LOCAL_ONLY,
UPLOADED, UPLOADED_UNDOWNLOADED,
UPLOADED_FINAL,
IN_PROGRESS IN_PROGRESS
} }
companion object {
fun from(backupKey: BackupKey, dbAttachment: DatabaseAttachment): BackupAttachment {
return BackupAttachment(
dbAttachment = dbAttachment,
mediaId = backupKey.deriveMediaId(Base64.decode(dbAttachment.dataHash!!)).toString()
)
}
}
} }
data class MediaStateError( data class MediaStateError(
@@ -2450,7 +2450,8 @@ public final class ConversationItem extends RelativeLayout implements BindableCo
for (Slide slide : slides) { for (Slide slide : slides) {
ApplicationDependencies.getJobManager().add(new AttachmentDownloadJob(messageRecord.getId(), ApplicationDependencies.getJobManager().add(new AttachmentDownloadJob(messageRecord.getId(),
((DatabaseAttachment) slide.asAttachment()).attachmentId, ((DatabaseAttachment) slide.asAttachment()).attachmentId,
true)); true,
false));
} }
} }
} }
@@ -2476,7 +2477,8 @@ public final class ConversationItem extends RelativeLayout implements BindableCo
setup(v, slide); setup(v, slide);
jobManager.add(new AttachmentDownloadJob(messageRecord.getId(), jobManager.add(new AttachmentDownloadJob(messageRecord.getId(),
attachmentId, attachmentId,
true)); true,
false));
jobManager.addListener(queue, (job, jobState) -> { jobManager.addListener(queue, (job, jobState) -> {
if (jobState.isComplete()) { if (jobState.isComplete()) {
cleanup(); cleanup();
@@ -10,6 +10,8 @@ import org.signal.core.util.Stopwatch
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import org.signal.core.util.toInt import org.signal.core.util.toInt
import org.signal.paging.PagedDataSource import org.signal.paging.PagedDataSource
import org.thoughtcrime.securesms.BuildConfig
import org.thoughtcrime.securesms.backup.v2.BackupRestoreManager
import org.thoughtcrime.securesms.conversation.ConversationData import org.thoughtcrime.securesms.conversation.ConversationData
import org.thoughtcrime.securesms.conversation.ConversationMessage import org.thoughtcrime.securesms.conversation.ConversationMessage
import org.thoughtcrime.securesms.conversation.ConversationMessage.ConversationMessageFactory import org.thoughtcrime.securesms.conversation.ConversationMessage.ConversationMessageFactory
@@ -20,6 +22,7 @@ import org.thoughtcrime.securesms.database.model.InMemoryMessageRecord.Universal
import org.thoughtcrime.securesms.database.model.MessageRecord import org.thoughtcrime.securesms.database.model.MessageRecord
import org.thoughtcrime.securesms.database.model.MmsMessageRecord import org.thoughtcrime.securesms.database.model.MmsMessageRecord
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.messagerequests.MessageRequestRepository import org.thoughtcrime.securesms.messagerequests.MessageRequestRepository
import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.util.adapter.mapping.MappingModel import org.thoughtcrime.securesms.util.adapter.mapping.MappingModel
@@ -122,6 +125,11 @@ class ConversationDataSource(
records = MessageDataFetcher.updateModelsWithData(records, extraData).toMutableList() records = MessageDataFetcher.updateModelsWithData(records, extraData).toMutableList()
stopwatch.split("models") stopwatch.split("models")
if (BuildConfig.MESSAGE_BACKUP_RESTORE_ENABLED && SignalStore.backup().restoreState.inProgress) {
BackupRestoreManager.prioritizeAttachmentsIfNeeded(records)
stopwatch.split("restore")
}
val messages = records.map { record -> val messages = records.map { record ->
ConversationMessageFactory.createWithUnresolvedData( ConversationMessageFactory.createWithUnresolvedData(
localContext, localContext,
@@ -52,13 +52,16 @@ import org.signal.core.util.requireInt
import org.signal.core.util.requireLong import org.signal.core.util.requireLong
import org.signal.core.util.requireNonNullBlob import org.signal.core.util.requireNonNullBlob
import org.signal.core.util.requireNonNullString import org.signal.core.util.requireNonNullString
import org.signal.core.util.requireObject
import org.signal.core.util.requireString import org.signal.core.util.requireString
import org.signal.core.util.select import org.signal.core.util.select
import org.signal.core.util.toInt import org.signal.core.util.toInt
import org.signal.core.util.update import org.signal.core.util.update
import org.signal.core.util.withinTransaction import org.signal.core.util.withinTransaction
import org.thoughtcrime.securesms.attachments.ArchivedAttachment
import org.thoughtcrime.securesms.attachments.Attachment import org.thoughtcrime.securesms.attachments.Attachment
import org.thoughtcrime.securesms.attachments.AttachmentId import org.thoughtcrime.securesms.attachments.AttachmentId
import org.thoughtcrime.securesms.attachments.Cdn
import org.thoughtcrime.securesms.attachments.DatabaseAttachment import org.thoughtcrime.securesms.attachments.DatabaseAttachment
import org.thoughtcrime.securesms.audio.AudioHash import org.thoughtcrime.securesms.audio.AudioHash
import org.thoughtcrime.securesms.blurhash.BlurHash import org.thoughtcrime.securesms.blurhash.BlurHash
@@ -140,6 +143,10 @@ class AttachmentTable(
const val TRANSFORM_PROPERTIES = "transform_properties" const val TRANSFORM_PROPERTIES = "transform_properties"
const val DISPLAY_ORDER = "display_order" const val DISPLAY_ORDER = "display_order"
const val UPLOAD_TIMESTAMP = "upload_timestamp" const val UPLOAD_TIMESTAMP = "upload_timestamp"
const val ARCHIVE_CDN = "archive_cdn"
const val ARCHIVE_MEDIA_NAME = "archive_media_name"
const val ARCHIVE_MEDIA_ID = "archive_media_id"
const val ARCHIVE_TRANSFER_FILE = "archive_transfer_file"
const val ATTACHMENT_JSON_ALIAS = "attachment_json" const val ATTACHMENT_JSON_ALIAS = "attachment_json"
@@ -150,6 +157,8 @@ class AttachmentTable(
const val TRANSFER_PROGRESS_PENDING = 2 const val TRANSFER_PROGRESS_PENDING = 2
const val TRANSFER_PROGRESS_FAILED = 3 const val TRANSFER_PROGRESS_FAILED = 3
const val TRANSFER_PROGRESS_PERMANENT_FAILURE = 4 const val TRANSFER_PROGRESS_PERMANENT_FAILURE = 4
const val TRANSFER_NEEDS_RESTORE = 5
const val TRANSFER_RESTORE_IN_PROGRESS = 6
const val PREUPLOAD_MESSAGE_ID: Long = -8675309 const val PREUPLOAD_MESSAGE_ID: Long = -8675309
private val PROJECTION = arrayOf( private val PROJECTION = arrayOf(
@@ -185,7 +194,11 @@ class AttachmentTable(
DISPLAY_ORDER, DISPLAY_ORDER,
UPLOAD_TIMESTAMP, UPLOAD_TIMESTAMP,
DATA_HASH_START, DATA_HASH_START,
DATA_HASH_END DATA_HASH_END,
ARCHIVE_CDN,
ARCHIVE_MEDIA_NAME,
ARCHIVE_MEDIA_ID,
ARCHIVE_TRANSFER_FILE
) )
const val CREATE_TABLE = """ const val CREATE_TABLE = """
@@ -222,7 +235,11 @@ class AttachmentTable(
$DISPLAY_ORDER INTEGER DEFAULT 0, $DISPLAY_ORDER INTEGER DEFAULT 0,
$UPLOAD_TIMESTAMP INTEGER DEFAULT 0, $UPLOAD_TIMESTAMP INTEGER DEFAULT 0,
$DATA_HASH_START TEXT DEFAULT NULL, $DATA_HASH_START TEXT DEFAULT NULL,
$DATA_HASH_END TEXT DEFAULT NULL $DATA_HASH_END TEXT DEFAULT NULL,
$ARCHIVE_CDN INTEGER DEFAULT 0,
$ARCHIVE_MEDIA_NAME TEXT DEFAULT NULL,
$ARCHIVE_MEDIA_ID TEXT DEFAULT NULL,
$ARCHIVE_TRANSFER_FILE TEXT DEFAULT NULL
) )
""" """
@@ -239,7 +256,6 @@ class AttachmentTable(
val ATTACHMENT_POINTER_REUSE_THRESHOLD = 7.days.inWholeMilliseconds val ATTACHMENT_POINTER_REUSE_THRESHOLD = 7.days.inWholeMilliseconds
@JvmStatic @JvmStatic
@JvmOverloads
@Throws(IOException::class) @Throws(IOException::class)
fun newDataFile(context: Context): File { fun newDataFile(context: Context): File {
val partsDirectory = context.getDir(DIRECTORY, Context.MODE_PRIVATE) val partsDirectory = context.getDir(DIRECTORY, Context.MODE_PRIVATE)
@@ -388,6 +404,27 @@ class AttachmentTable(
.flatten() .flatten()
} }
fun getArchivableAttachments(): Cursor {
return readableDatabase
.select(*PROJECTION)
.from(TABLE_NAME)
.where("$ARCHIVE_MEDIA_ID IS NULL AND $REMOTE_DIGEST IS NOT NULL AND ($TRANSFER_STATE = ? OR $TRANSFER_STATE = ?)", TRANSFER_PROGRESS_DONE.toString(), TRANSFER_NEEDS_RESTORE.toString())
.orderBy("$ID DESC")
.run()
}
fun getRestorableAttachments(batchSize: Int): List<DatabaseAttachment> {
return readableDatabase
.select(*PROJECTION)
.from(TABLE_NAME)
.where("$TRANSFER_STATE = ?", TRANSFER_NEEDS_RESTORE.toString())
.limit(batchSize)
.orderBy("$ID DESC")
.run().readToList {
it.readAttachments()
}.flatten()
}
fun deleteAttachmentsForMessage(mmsId: Long): Boolean { fun deleteAttachmentsForMessage(mmsId: Long): Boolean {
Log.d(TAG, "[deleteAttachmentsForMessage] mmsId: $mmsId") Log.d(TAG, "[deleteAttachmentsForMessage] mmsId: $mmsId")
@@ -679,6 +716,7 @@ class AttachmentTable(
values.put(TRANSFER_STATE, TRANSFER_PROGRESS_DONE) values.put(TRANSFER_STATE, TRANSFER_PROGRESS_DONE)
values.put(TRANSFER_FILE, null as String?) values.put(TRANSFER_FILE, null as String?)
values.put(TRANSFORM_PROPERTIES, TransformProperties.forSkipTransform().serialize()) values.put(TRANSFORM_PROPERTIES, TransformProperties.forSkipTransform().serialize())
values.put(ARCHIVE_TRANSFER_FILE, null as String?)
db.update(TABLE_NAME) db.update(TABLE_NAME)
.values(values) .values(values)
@@ -734,7 +772,7 @@ class AttachmentTable(
val values = contentValuesOf( val values = contentValuesOf(
TRANSFER_STATE to TRANSFER_PROGRESS_DONE, TRANSFER_STATE to TRANSFER_PROGRESS_DONE,
CDN_NUMBER to attachment.cdnNumber, CDN_NUMBER to attachment.cdn.serialize(),
REMOTE_LOCATION to attachment.remoteLocation, REMOTE_LOCATION to attachment.remoteLocation,
REMOTE_DIGEST to attachment.remoteDigest, REMOTE_DIGEST to attachment.remoteDigest,
REMOTE_INCREMENTAL_DIGEST to attachment.incrementalDigest, REMOTE_INCREMENTAL_DIGEST to attachment.incrementalDigest,
@@ -774,7 +812,7 @@ class AttachmentTable(
DATA_SIZE to sourceDataInfo.length, DATA_SIZE to sourceDataInfo.length,
DATA_RANDOM to sourceDataInfo.random, DATA_RANDOM to sourceDataInfo.random,
TRANSFER_STATE to sourceAttachment.transferState, TRANSFER_STATE to sourceAttachment.transferState,
CDN_NUMBER to sourceAttachment.cdnNumber, CDN_NUMBER to sourceAttachment.cdn.serialize(),
REMOTE_LOCATION to sourceAttachment.remoteLocation, REMOTE_LOCATION to sourceAttachment.remoteLocation,
REMOTE_DIGEST to sourceAttachment.remoteDigest, REMOTE_DIGEST to sourceAttachment.remoteDigest,
REMOTE_INCREMENTAL_DIGEST to sourceAttachment.incrementalDigest, REMOTE_INCREMENTAL_DIGEST to sourceAttachment.incrementalDigest,
@@ -865,7 +903,11 @@ class AttachmentTable(
val attachmentId = if (attachment.uri != null) { val attachmentId = if (attachment.uri != null) {
insertAttachmentWithData(mmsId, attachment, attachment.quote) insertAttachmentWithData(mmsId, attachment, attachment.quote)
} else { } else {
insertUndownloadedAttachment(mmsId, attachment, attachment.quote) if (attachment is ArchivedAttachment) {
insertArchivedAttachment(mmsId, attachment, attachment.quote)
} else {
insertUndownloadedAttachment(mmsId, attachment, attachment.quote)
}
} }
insertedAttachments[attachment] = attachmentId insertedAttachments[attachment] = attachmentId
@@ -890,6 +932,75 @@ class AttachmentTable(
return insertedAttachments return insertedAttachments
} }
fun debugCopyAttachmentForArchiveRestore(
mmsId: Long,
attachment: DatabaseAttachment
) {
val copy =
"""
INSERT INTO $TABLE_NAME
(
$MESSAGE_ID,
$CONTENT_TYPE,
$TRANSFER_STATE,
$CDN_NUMBER,
$REMOTE_LOCATION,
$REMOTE_DIGEST,
$REMOTE_INCREMENTAL_DIGEST,
$REMOTE_INCREMENTAL_DIGEST_CHUNK_SIZE,
$REMOTE_KEY,
$FILE_NAME,
$DATA_SIZE,
$VOICE_NOTE,
$BORDERLESS,
$VIDEO_GIF,
$WIDTH,
$HEIGHT,
$CAPTION,
$UPLOAD_TIMESTAMP,
$BLUR_HASH,
$DATA_SIZE,
$DATA_RANDOM,
$DATA_HASH_START,
$DATA_HASH_END,
$ARCHIVE_MEDIA_ID,
$ARCHIVE_MEDIA_NAME,
$ARCHIVE_CDN
)
SELECT
$mmsId,
$CONTENT_TYPE,
$TRANSFER_PROGRESS_PENDING,
$CDN_NUMBER,
$REMOTE_LOCATION,
$REMOTE_DIGEST,
$REMOTE_INCREMENTAL_DIGEST,
$REMOTE_INCREMENTAL_DIGEST_CHUNK_SIZE,
$REMOTE_KEY,
$FILE_NAME,
$DATA_SIZE,
$VOICE_NOTE,
$BORDERLESS,
$VIDEO_GIF,
$WIDTH,
$HEIGHT,
$CAPTION,
${System.currentTimeMillis()},
$BLUR_HASH,
$DATA_SIZE,
$DATA_RANDOM,
$DATA_HASH_START,
$DATA_HASH_END,
"${attachment.archiveMediaId}",
"${attachment.archiveMediaName}",
${attachment.archiveCdn}
FROM $TABLE_NAME
WHERE $ID = ${attachment.attachmentId.id}
"""
writableDatabase.execSQL(copy)
}
/** /**
* Updates the data stored for an existing attachment. This happens after transformations, like transcoding. * Updates the data stored for an existing attachment. This happens after transformations, like transcoding.
*/ */
@@ -956,6 +1067,24 @@ class AttachmentTable(
return transferFile return transferFile
} }
@Throws(IOException::class)
fun getOrCreateArchiveTransferFile(attachmentId: AttachmentId): File {
val existing = getArchiveTransferFile(writableDatabase, attachmentId)
if (existing != null) {
return existing
}
val transferFile = newTransferFile()
writableDatabase
.update(TABLE_NAME)
.values(ARCHIVE_TRANSFER_FILE to transferFile.absolutePath)
.where("$ID = ?", attachmentId.id)
.run()
return transferFile
}
fun getDataFileInfo(attachmentId: AttachmentId): DataFileInfo? { fun getDataFileInfo(attachmentId: AttachmentId): DataFileInfo? {
return readableDatabase return readableDatabase
.select(ID, DATA_FILE, DATA_SIZE, DATA_RANDOM, DATA_HASH_START, DATA_HASH_END, TRANSFORM_PROPERTIES, UPLOAD_TIMESTAMP) .select(ID, DATA_FILE, DATA_SIZE, DATA_RANDOM, DATA_HASH_START, DATA_HASH_END, TRANSFORM_PROPERTIES, UPLOAD_TIMESTAMP)
@@ -1087,7 +1216,7 @@ class AttachmentTable(
transferProgress = jsonObject.getInt(TRANSFER_STATE), transferProgress = jsonObject.getInt(TRANSFER_STATE),
size = jsonObject.getLong(DATA_SIZE), size = jsonObject.getLong(DATA_SIZE),
fileName = jsonObject.getString(FILE_NAME), fileName = jsonObject.getString(FILE_NAME),
cdnNumber = jsonObject.getInt(CDN_NUMBER), cdn = Cdn.deserialize(jsonObject.getInt(CDN_NUMBER)),
location = jsonObject.getString(REMOTE_LOCATION), location = jsonObject.getString(REMOTE_LOCATION),
key = jsonObject.getString(REMOTE_KEY), key = jsonObject.getString(REMOTE_KEY),
digest = null, digest = null,
@@ -1116,7 +1245,10 @@ class AttachmentTable(
transformProperties = TransformProperties.parse(jsonObject.getString(TRANSFORM_PROPERTIES)), transformProperties = TransformProperties.parse(jsonObject.getString(TRANSFORM_PROPERTIES)),
displayOrder = jsonObject.getInt(DISPLAY_ORDER), displayOrder = jsonObject.getInt(DISPLAY_ORDER),
uploadTimestamp = jsonObject.getLong(UPLOAD_TIMESTAMP), uploadTimestamp = jsonObject.getLong(UPLOAD_TIMESTAMP),
dataHash = jsonObject.getString(DATA_HASH_END) dataHash = jsonObject.getString(DATA_HASH_END),
archiveCdn = jsonObject.getInt(ARCHIVE_CDN),
archiveMediaName = jsonObject.getString(ARCHIVE_MEDIA_NAME),
archiveMediaId = jsonObject.getString(ARCHIVE_MEDIA_ID)
) )
} }
} }
@@ -1156,6 +1288,45 @@ class AttachmentTable(
return readableDatabase.rawQuery(query, null) return readableDatabase.rawQuery(query, null)
} }
fun setArchiveData(attachmentId: AttachmentId, archiveCdn: Int, archiveMediaName: String, archiveMediaId: String) {
writableDatabase
.update(TABLE_NAME)
.values(
ARCHIVE_CDN to archiveCdn,
ARCHIVE_MEDIA_ID to archiveMediaId,
ARCHIVE_MEDIA_NAME to archiveMediaName
)
.where("$ID = ?", attachmentId.id)
.run()
}
fun clearArchiveData(attachmentIds: List<AttachmentId>) {
SqlUtil.buildCollectionQuery(ID, attachmentIds.map { it.id })
.forEach { query ->
writableDatabase
.update(TABLE_NAME)
.values(
ARCHIVE_CDN to 0,
ARCHIVE_MEDIA_ID to null,
ARCHIVE_MEDIA_NAME to null
)
.where(query.where, query.whereArgs)
.run()
}
}
fun clearAllArchiveData() {
writableDatabase
.update(TABLE_NAME)
.values(
ARCHIVE_CDN to 0,
ARCHIVE_MEDIA_ID to null,
ARCHIVE_MEDIA_NAME to null
)
.where("$ARCHIVE_CDN > 0 OR $ARCHIVE_MEDIA_ID IS NOT NULL OR $ARCHIVE_MEDIA_NAME IS NOT NULL")
.run()
}
/** /**
* Deletes the data file if there's no strong references to other attachments. * Deletes the data file if there's no strong references to other attachments.
* If deleted, it will also clear all weak references (i.e. quotes) of the attachment. * If deleted, it will also clear all weak references (i.e. quotes) of the attachment.
@@ -1338,7 +1509,7 @@ class AttachmentTable(
put(MESSAGE_ID, messageId) put(MESSAGE_ID, messageId)
put(CONTENT_TYPE, attachment.contentType) put(CONTENT_TYPE, attachment.contentType)
put(TRANSFER_STATE, attachment.transferState) put(TRANSFER_STATE, attachment.transferState)
put(CDN_NUMBER, attachment.cdnNumber) put(CDN_NUMBER, attachment.cdn.serialize())
put(REMOTE_LOCATION, attachment.remoteLocation) put(REMOTE_LOCATION, attachment.remoteLocation)
put(REMOTE_DIGEST, attachment.remoteDigest) put(REMOTE_DIGEST, attachment.remoteDigest)
put(REMOTE_INCREMENTAL_DIGEST, attachment.incrementalDigest) put(REMOTE_INCREMENTAL_DIGEST, attachment.incrementalDigest)
@@ -1373,6 +1544,59 @@ class AttachmentTable(
return attachmentId return attachmentId
} }
/**
* Attachments need records in the database even if they haven't been downloaded yet. That allows us to store the info we need to download it, what message
* it's associated with, etc. We treat this case separately from attachments with data (see [insertAttachmentWithData]) because it's much simpler,
* and splitting the two use cases makes the code easier to understand.
*
* Callers are expected to later call [finalizeAttachmentAfterDownload] once they have downloaded the data for this attachment.
*/
@Throws(MmsException::class)
private fun insertArchivedAttachment(messageId: Long, attachment: ArchivedAttachment, quote: Boolean): AttachmentId {
Log.d(TAG, "[insertAttachment] Inserting attachment for messageId $messageId.")
val attachmentId: AttachmentId = writableDatabase.withinTransaction { db ->
val contentValues = ContentValues().apply {
put(MESSAGE_ID, messageId)
put(CONTENT_TYPE, attachment.contentType)
put(TRANSFER_STATE, attachment.transferState)
put(CDN_NUMBER, attachment.cdn.serialize())
put(REMOTE_LOCATION, attachment.remoteLocation)
put(REMOTE_DIGEST, attachment.remoteDigest)
put(REMOTE_INCREMENTAL_DIGEST, attachment.incrementalDigest)
put(REMOTE_INCREMENTAL_DIGEST_CHUNK_SIZE, attachment.incrementalMacChunkSize)
put(REMOTE_KEY, attachment.remoteKey)
put(FILE_NAME, StorageUtil.getCleanFileName(attachment.fileName))
put(DATA_SIZE, attachment.size)
put(FAST_PREFLIGHT_ID, attachment.fastPreflightId)
put(VOICE_NOTE, attachment.voiceNote.toInt())
put(BORDERLESS, attachment.borderless.toInt())
put(VIDEO_GIF, attachment.videoGif.toInt())
put(WIDTH, attachment.width)
put(HEIGHT, attachment.height)
put(QUOTE, quote)
put(CAPTION, attachment.caption)
put(UPLOAD_TIMESTAMP, attachment.uploadTimestamp)
put(ARCHIVE_CDN, attachment.archiveCdn)
put(ARCHIVE_MEDIA_NAME, attachment.archiveMediaName)
put(ARCHIVE_MEDIA_ID, attachment.archiveMediaId)
attachment.stickerLocator?.let { sticker ->
put(STICKER_PACK_ID, sticker.packId)
put(STICKER_PACK_KEY, sticker.packKey)
put(STICKER_ID, sticker.stickerId)
put(STICKER_EMOJI, sticker.emoji)
}
}
val rowId = db.insert(TABLE_NAME, null, contentValues)
AttachmentId(rowId)
}
notifyAttachmentListeners()
return attachmentId
}
/** /**
* Inserts an attachment with existing data. This is likely an outgoing attachment that we're in the process of sending. * Inserts an attachment with existing data. This is likely an outgoing attachment that we're in the process of sending.
*/ */
@@ -1462,7 +1686,7 @@ class AttachmentTable(
contentValues.put(MESSAGE_ID, messageId) contentValues.put(MESSAGE_ID, messageId)
contentValues.put(CONTENT_TYPE, uploadTemplate?.contentType ?: attachment.contentType) contentValues.put(CONTENT_TYPE, uploadTemplate?.contentType ?: attachment.contentType)
contentValues.put(TRANSFER_STATE, attachment.transferState) // Even if we have a template, we let AttachmentUploadJob have the final say so it can re-check and make sure the template is still valid contentValues.put(TRANSFER_STATE, attachment.transferState) // Even if we have a template, we let AttachmentUploadJob have the final say so it can re-check and make sure the template is still valid
contentValues.put(CDN_NUMBER, uploadTemplate?.cdnNumber ?: 0) contentValues.put(CDN_NUMBER, uploadTemplate?.cdn?.serialize() ?: Cdn.CDN_0.serialize())
contentValues.put(REMOTE_LOCATION, uploadTemplate?.remoteLocation) contentValues.put(REMOTE_LOCATION, uploadTemplate?.remoteLocation)
contentValues.put(REMOTE_DIGEST, uploadTemplate?.remoteDigest) contentValues.put(REMOTE_DIGEST, uploadTemplate?.remoteDigest)
contentValues.put(REMOTE_INCREMENTAL_DIGEST, uploadTemplate?.incrementalDigest) contentValues.put(REMOTE_INCREMENTAL_DIGEST, uploadTemplate?.incrementalDigest)
@@ -1520,6 +1744,18 @@ class AttachmentTable(
} }
} }
private fun getArchiveTransferFile(db: SQLiteDatabase, attachmentId: AttachmentId): File? {
return db
.select(ARCHIVE_TRANSFER_FILE)
.from(TABLE_NAME)
.where("$ID = ?", attachmentId.id)
.limit(1)
.run()
.readToSingleObject { cursor ->
cursor.requireString(ARCHIVE_TRANSFER_FILE)?.let { File(it) }
}
}
private fun getAttachment(cursor: Cursor): DatabaseAttachment { private fun getAttachment(cursor: Cursor): DatabaseAttachment {
val contentType = cursor.requireString(CONTENT_TYPE) val contentType = cursor.requireString(CONTENT_TYPE)
@@ -1532,7 +1768,7 @@ class AttachmentTable(
transferProgress = cursor.requireInt(TRANSFER_STATE), transferProgress = cursor.requireInt(TRANSFER_STATE),
size = cursor.requireLong(DATA_SIZE), size = cursor.requireLong(DATA_SIZE),
fileName = cursor.requireString(FILE_NAME), fileName = cursor.requireString(FILE_NAME),
cdnNumber = cursor.requireInt(CDN_NUMBER), cdn = cursor.requireObject(CDN_NUMBER, Cdn.Serializer),
location = cursor.requireString(REMOTE_LOCATION), location = cursor.requireString(REMOTE_LOCATION),
key = cursor.requireString(REMOTE_KEY), key = cursor.requireString(REMOTE_KEY),
digest = cursor.requireBlob(REMOTE_DIGEST), digest = cursor.requireBlob(REMOTE_DIGEST),
@@ -1552,7 +1788,10 @@ class AttachmentTable(
transformProperties = TransformProperties.parse(cursor.requireString(TRANSFORM_PROPERTIES)), transformProperties = TransformProperties.parse(cursor.requireString(TRANSFORM_PROPERTIES)),
displayOrder = cursor.requireInt(DISPLAY_ORDER), displayOrder = cursor.requireInt(DISPLAY_ORDER),
uploadTimestamp = cursor.requireLong(UPLOAD_TIMESTAMP), uploadTimestamp = cursor.requireLong(UPLOAD_TIMESTAMP),
dataHash = cursor.requireString(DATA_HASH_END) dataHash = cursor.requireString(DATA_HASH_END),
archiveCdn = cursor.requireInt(ARCHIVE_CDN),
archiveMediaName = cursor.requireString(ARCHIVE_MEDIA_NAME),
archiveMediaId = cursor.requireString(ARCHIVE_MEDIA_ID)
) )
} }
@@ -1603,7 +1842,7 @@ class AttachmentTable(
return readableDatabase return readableDatabase
.select(*PROJECTION) .select(*PROJECTION)
.from(TABLE_NAME) .from(TABLE_NAME)
.where("$TRANSFER_STATE == $TRANSFER_PROGRESS_DONE AND $REMOTE_LOCATION IS NOT NULL AND $DATA_HASH_END IS NOT NULL") .where("$REMOTE_LOCATION IS NOT NULL AND $REMOTE_KEY IS NOT NULL")
.orderBy("$ID DESC") .orderBy("$ID DESC")
.limit(30) .limit(30)
.run() .run()
@@ -56,6 +56,7 @@ import org.whispersystems.signalservice.api.groupsv2.findRequestingByAci
import org.whispersystems.signalservice.api.groupsv2.toAciList import org.whispersystems.signalservice.api.groupsv2.toAciList
import org.whispersystems.signalservice.api.groupsv2.toAciListWithUnknowns import org.whispersystems.signalservice.api.groupsv2.toAciListWithUnknowns
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId
import org.whispersystems.signalservice.api.push.DistributionId import org.whispersystems.signalservice.api.push.DistributionId
import org.whispersystems.signalservice.api.push.ServiceId import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.api.push.ServiceId.ACI import org.whispersystems.signalservice.api.push.ServiceId.ACI
@@ -746,7 +747,7 @@ class GroupTable(context: Context?, databaseHelper: SignalDatabase?) : DatabaseT
values.put(MMS, groupId.isMms) values.put(MMS, groupId.isMms)
if (avatar != null) { if (avatar != null) {
values.put(AVATAR_ID, avatar.remoteId.v2.get()) values.put(AVATAR_ID, (avatar.remoteId as SignalServiceAttachmentRemoteId.V2).cdnId)
values.put(AVATAR_KEY, avatar.key) values.put(AVATAR_KEY, avatar.key)
values.put(AVATAR_CONTENT_TYPE, avatar.contentType) values.put(AVATAR_CONTENT_TYPE, avatar.contentType)
values.put(AVATAR_DIGEST, avatar.digest.orElse(null)) values.put(AVATAR_DIGEST, avatar.digest.orElse(null))
@@ -822,7 +823,7 @@ class GroupTable(context: Context?, databaseHelper: SignalDatabase?) : DatabaseT
} }
if (avatar != null) { if (avatar != null) {
put(AVATAR_ID, avatar.remoteId.v2.get()) put(AVATAR_ID, (avatar.remoteId as SignalServiceAttachmentRemoteId.V2).cdnId)
put(AVATAR_CONTENT_TYPE, avatar.contentType) put(AVATAR_CONTENT_TYPE, avatar.contentType)
put(AVATAR_KEY, avatar.key) put(AVATAR_KEY, avatar.key)
put(AVATAR_DIGEST, avatar.digest.orElse(null)) put(AVATAR_DIGEST, avatar.digest.orElse(null))
@@ -50,6 +50,9 @@ class MediaTable internal constructor(context: Context?, databaseHelper: SignalD
${AttachmentTable.TABLE_NAME}.${AttachmentTable.REMOTE_INCREMENTAL_DIGEST}, ${AttachmentTable.TABLE_NAME}.${AttachmentTable.REMOTE_INCREMENTAL_DIGEST},
${AttachmentTable.TABLE_NAME}.${AttachmentTable.REMOTE_INCREMENTAL_DIGEST_CHUNK_SIZE}, ${AttachmentTable.TABLE_NAME}.${AttachmentTable.REMOTE_INCREMENTAL_DIGEST_CHUNK_SIZE},
${AttachmentTable.TABLE_NAME}.${AttachmentTable.DATA_HASH_END}, ${AttachmentTable.TABLE_NAME}.${AttachmentTable.DATA_HASH_END},
${AttachmentTable.TABLE_NAME}.${AttachmentTable.ARCHIVE_CDN},
${AttachmentTable.TABLE_NAME}.${AttachmentTable.ARCHIVE_MEDIA_NAME},
${AttachmentTable.TABLE_NAME}.${AttachmentTable.ARCHIVE_MEDIA_ID},
${MessageTable.TABLE_NAME}.${MessageTable.TYPE}, ${MessageTable.TABLE_NAME}.${MessageTable.TYPE},
${MessageTable.TABLE_NAME}.${MessageTable.DATE_SENT}, ${MessageTable.TABLE_NAME}.${MessageTable.DATE_SENT},
${MessageTable.TABLE_NAME}.${MessageTable.DATE_RECEIVED}, ${MessageTable.TABLE_NAME}.${MessageTable.DATE_RECEIVED},
@@ -376,7 +376,11 @@ open class MessageTable(context: Context?, databaseHelper: SignalDatabase) : Dat
'${AttachmentTable.BLUR_HASH}', ${AttachmentTable.TABLE_NAME}.${AttachmentTable.BLUR_HASH}, '${AttachmentTable.BLUR_HASH}', ${AttachmentTable.TABLE_NAME}.${AttachmentTable.BLUR_HASH},
'${AttachmentTable.TRANSFORM_PROPERTIES}', ${AttachmentTable.TABLE_NAME}.${AttachmentTable.TRANSFORM_PROPERTIES}, '${AttachmentTable.TRANSFORM_PROPERTIES}', ${AttachmentTable.TABLE_NAME}.${AttachmentTable.TRANSFORM_PROPERTIES},
'${AttachmentTable.DISPLAY_ORDER}', ${AttachmentTable.TABLE_NAME}.${AttachmentTable.DISPLAY_ORDER}, '${AttachmentTable.DISPLAY_ORDER}', ${AttachmentTable.TABLE_NAME}.${AttachmentTable.DISPLAY_ORDER},
'${AttachmentTable.UPLOAD_TIMESTAMP}', ${AttachmentTable.TABLE_NAME}.${AttachmentTable.UPLOAD_TIMESTAMP} '${AttachmentTable.UPLOAD_TIMESTAMP}', ${AttachmentTable.TABLE_NAME}.${AttachmentTable.UPLOAD_TIMESTAMP},
'${AttachmentTable.DATA_HASH_END}', ${AttachmentTable.TABLE_NAME}.${AttachmentTable.DATA_HASH_END},
'${AttachmentTable.ARCHIVE_CDN}', ${AttachmentTable.TABLE_NAME}.${AttachmentTable.ARCHIVE_CDN},
'${AttachmentTable.ARCHIVE_MEDIA_NAME}', ${AttachmentTable.TABLE_NAME}.${AttachmentTable.ARCHIVE_MEDIA_NAME},
'${AttachmentTable.ARCHIVE_MEDIA_ID}', ${AttachmentTable.TABLE_NAME}.${AttachmentTable.ARCHIVE_MEDIA_ID}
) )
) AS ${AttachmentTable.ATTACHMENT_JSON_ALIAS} ) AS ${AttachmentTable.ATTACHMENT_JSON_ALIAS}
""".toSingleLine() """.toSingleLine()
@@ -81,6 +81,7 @@ import org.thoughtcrime.securesms.database.helpers.migration.V220_PreKeyConstrai
import org.thoughtcrime.securesms.database.helpers.migration.V221_AddReadColumnToCallEventsTable import org.thoughtcrime.securesms.database.helpers.migration.V221_AddReadColumnToCallEventsTable
import org.thoughtcrime.securesms.database.helpers.migration.V222_DataHashRefactor import org.thoughtcrime.securesms.database.helpers.migration.V222_DataHashRefactor
import org.thoughtcrime.securesms.database.helpers.migration.V223_AddNicknameAndNoteFieldsToRecipientTable import org.thoughtcrime.securesms.database.helpers.migration.V223_AddNicknameAndNoteFieldsToRecipientTable
import org.thoughtcrime.securesms.database.helpers.migration.V224_AddAttachmentArchiveColumns
/** /**
* Contains all of the database migrations for [SignalDatabase]. Broken into a separate file for cleanliness. * Contains all of the database migrations for [SignalDatabase]. Broken into a separate file for cleanliness.
@@ -164,10 +165,11 @@ object SignalDatabaseMigrations {
220 to V220_PreKeyConstraints, 220 to V220_PreKeyConstraints,
221 to V221_AddReadColumnToCallEventsTable, 221 to V221_AddReadColumnToCallEventsTable,
222 to V222_DataHashRefactor, 222 to V222_DataHashRefactor,
223 to V223_AddNicknameAndNoteFieldsToRecipientTable 223 to V223_AddNicknameAndNoteFieldsToRecipientTable,
224 to V224_AddAttachmentArchiveColumns
) )
const val DATABASE_VERSION = 223 const val DATABASE_VERSION = 224
@JvmStatic @JvmStatic
fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) { fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
@@ -0,0 +1,22 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.database.helpers.migration
import android.app.Application
import net.zetetic.database.sqlcipher.SQLiteDatabase
/**
* Adds archive_cdn and archive_media to attachment.
*/
@Suppress("ClassName")
object V224_AddAttachmentArchiveColumns : SignalDatabaseMigration {
override fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
db.execSQL("ALTER TABLE attachment ADD COLUMN archive_cdn INTEGER DEFAULT 0")
db.execSQL("ALTER TABLE attachment ADD COLUMN archive_media_name TEXT DEFAULT NULL")
db.execSQL("ALTER TABLE attachment ADD COLUMN archive_media_id TEXT DEFAULT NULL")
db.execSQL("ALTER TABLE attachment ADD COLUMN archive_transfer_file TEXT DEFAULT NULL")
}
}
@@ -0,0 +1,78 @@
package org.thoughtcrime.securesms.jobs
import org.signal.core.util.logging.Log
import org.thoughtcrime.securesms.attachments.AttachmentId
import org.thoughtcrime.securesms.backup.v2.BackupRepository
import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
import org.thoughtcrime.securesms.jobmanager.Job
import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint
import org.thoughtcrime.securesms.jobs.protos.ArchiveAttachmentJobData
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException
import java.io.IOException
import java.util.concurrent.TimeUnit
/**
* Copies and re-encrypts attachments from the attachment cdn to the archive cdn.
*
* Job will fail if the attachment isn't available on the attachment cdn, use [AttachmentUploadJob] to upload first if necessary.
*/
class ArchiveAttachmentJob private constructor(private val attachmentId: AttachmentId, parameters: Parameters) : BaseJob(parameters) {
companion object {
private val TAG = Log.tag(ArchiveAttachmentJob::class.java)
const val KEY = "ArchiveAttachmentJob"
fun enqueueIfPossible(attachmentId: AttachmentId) {
if (!SignalStore.backup().canReadWriteToArchiveCdn) {
return
}
ApplicationDependencies.getJobManager().add(ArchiveAttachmentJob(attachmentId))
}
}
constructor(attachmentId: AttachmentId) : this(
attachmentId = attachmentId,
parameters = Parameters.Builder()
.addConstraint(NetworkConstraint.KEY)
.setLifespan(TimeUnit.DAYS.toMillis(1))
.setMaxAttempts(Parameters.UNLIMITED)
.build()
)
override fun serialize(): ByteArray = ArchiveAttachmentJobData(attachmentId.id).encode()
override fun getFactoryKey(): String = KEY
override fun onRun() {
if (!SignalStore.backup().canReadWriteToArchiveCdn) {
Log.w(TAG, "Do not have permission to read/write to archive cdn")
return
}
val attachment = SignalDatabase.attachments.getAttachment(attachmentId)
if (attachment == null) {
Log.w(TAG, "Unable to find attachment to archive: $attachmentId")
return
}
BackupRepository.archiveMedia(attachment).successOrThrow()
}
override fun onShouldRetry(e: Exception): Boolean {
return e is IOException && e !is NonSuccessfulResponseCodeException
}
override fun onFailure() = Unit
class Factory : Job.Factory<ArchiveAttachmentJob> {
override fun create(parameters: Parameters, serializedData: ByteArray?): ArchiveAttachmentJob {
val jobData = ArchiveAttachmentJobData.ADAPTER.decode(serializedData!!)
return ArchiveAttachmentJob(AttachmentId(jobData.attachmentId), parameters)
}
}
}
@@ -1,350 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.jobs;
import android.text.TextUtils;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.annotation.VisibleForTesting;
import org.greenrobot.eventbus.EventBus;
import org.signal.core.util.Hex;
import org.signal.core.util.logging.Log;
import org.signal.libsignal.protocol.InvalidMacException;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.thoughtcrime.securesms.attachments.Attachment;
import org.thoughtcrime.securesms.attachments.AttachmentId;
import org.thoughtcrime.securesms.attachments.DatabaseAttachment;
import org.thoughtcrime.securesms.blurhash.BlurHash;
import org.thoughtcrime.securesms.database.AttachmentTable;
import org.thoughtcrime.securesms.database.SignalDatabase;
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies;
import org.thoughtcrime.securesms.events.PartProgressEvent;
import org.thoughtcrime.securesms.jobmanager.Job;
import org.thoughtcrime.securesms.jobmanager.JobLogger;
import org.thoughtcrime.securesms.jobmanager.JsonJobData;
import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint;
import org.thoughtcrime.securesms.jobmanager.persistence.JobSpec;
import org.thoughtcrime.securesms.mms.MmsException;
import org.thoughtcrime.securesms.notifications.v2.ConversationId;
import org.thoughtcrime.securesms.releasechannel.ReleaseChannel;
import org.thoughtcrime.securesms.s3.S3;
import org.thoughtcrime.securesms.transport.RetryLaterException;
import org.thoughtcrime.securesms.util.AttachmentUtil;
import org.signal.core.util.Base64;
import org.thoughtcrime.securesms.util.FeatureFlags;
import org.thoughtcrime.securesms.util.Util;
import org.whispersystems.signalservice.api.SignalServiceMessageReceiver;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId;
import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException;
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException;
import org.whispersystems.signalservice.api.push.exceptions.PushNetworkException;
import org.whispersystems.signalservice.api.push.exceptions.RangeException;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okio.Okio;
public final class AttachmentDownloadJob extends BaseJob {
public static final String KEY = "AttachmentDownloadJob";
private static final String TAG = Log.tag(AttachmentDownloadJob.class);
private static final String KEY_MESSAGE_ID = "message_id";
private static final String KEY_ATTACHMENT_ID = "part_row_id";
private static final String KEY_MANUAL = "part_manual";
private final long messageId;
private final long attachmentId;
private final boolean manual;
public AttachmentDownloadJob(long messageId, AttachmentId attachmentId, boolean manual) {
this(new Job.Parameters.Builder()
.setQueue(constructQueueString(attachmentId))
.addConstraint(NetworkConstraint.KEY)
.setLifespan(TimeUnit.DAYS.toMillis(1))
.setMaxAttempts(Parameters.UNLIMITED)
.build(),
messageId,
attachmentId,
manual);
}
private AttachmentDownloadJob(@NonNull Job.Parameters parameters, long messageId, AttachmentId attachmentId, boolean manual) {
super(parameters);
this.messageId = messageId;
this.attachmentId = attachmentId.id;
this.manual = manual;
}
@Override
public @Nullable byte[] serialize() {
return new JsonJobData.Builder().putLong(KEY_MESSAGE_ID, messageId)
.putLong(KEY_ATTACHMENT_ID, attachmentId)
.putBoolean(KEY_MANUAL, manual)
.serialize();
}
@Override
public @NonNull String getFactoryKey() {
return KEY;
}
public static String constructQueueString(AttachmentId attachmentId) {
return "AttachmentDownloadJob-" + attachmentId.id;
}
@Override
public void onAdded() {
Log.i(TAG, "onAdded() messageId: " + messageId + " attachmentId: " + attachmentId + " manual: " + manual);
final AttachmentTable database = SignalDatabase.attachments();
final AttachmentId attachmentId = new AttachmentId(this.attachmentId);
final DatabaseAttachment attachment = database.getAttachment(attachmentId);
final boolean pending = attachment != null && attachment.transferState != AttachmentTable.TRANSFER_PROGRESS_DONE
&& attachment.transferState != AttachmentTable.TRANSFER_PROGRESS_PERMANENT_FAILURE;
if (pending && (manual || AttachmentUtil.isAutoDownloadPermitted(context, attachment))) {
Log.i(TAG, "onAdded() Marking attachment progress as 'started'");
database.setTransferState(messageId, attachmentId, AttachmentTable.TRANSFER_PROGRESS_STARTED);
}
}
@Override
public void onRun() throws Exception {
doWork();
if (!SignalDatabase.messages().isStory(messageId)) {
ApplicationDependencies.getMessageNotifier().updateNotification(context, ConversationId.forConversation(0));
}
}
public void doWork() throws IOException, RetryLaterException {
Log.i(TAG, "onRun() messageId: " + messageId + " attachmentId: " + attachmentId + " manual: " + manual);
final AttachmentTable database = SignalDatabase.attachments();
final AttachmentId attachmentId = new AttachmentId(this.attachmentId);
final DatabaseAttachment attachment = database.getAttachment(attachmentId);
if (attachment == null) {
Log.w(TAG, "attachment no longer exists.");
return;
}
if (attachment.isPermanentlyFailed()) {
Log.w(TAG, "Attachment was marked as a permanent failure. Refusing to download.");
return;
}
if (!attachment.isInProgress()) {
Log.w(TAG, "Attachment was already downloaded.");
return;
}
if (!manual && !AttachmentUtil.isAutoDownloadPermitted(context, attachment)) {
Log.w(TAG, "Attachment can't be auto downloaded...");
database.setTransferState(messageId, attachmentId, AttachmentTable.TRANSFER_PROGRESS_PENDING);
return;
}
Log.i(TAG, "Downloading push part " + attachmentId);
database.setTransferState(messageId, attachmentId, AttachmentTable.TRANSFER_PROGRESS_STARTED);
if (attachment.cdnNumber != ReleaseChannel.CDN_NUMBER) {
retrieveAttachment(messageId, attachmentId, attachment);
} else {
retrieveAttachmentForReleaseChannel(messageId, attachmentId, attachment);
}
}
@Override
public void onFailure() {
Log.w(TAG, JobLogger.format(this, "onFailure() messageId: " + messageId + " attachmentId: " + attachmentId + " manual: " + manual));
final AttachmentId attachmentId = new AttachmentId(this.attachmentId);
markFailed(messageId, attachmentId);
}
@Override
protected boolean onShouldRetry(@NonNull Exception exception) {
return exception instanceof PushNetworkException ||
exception instanceof RetryLaterException;
}
private void retrieveAttachment(long messageId,
final AttachmentId attachmentId,
final Attachment attachment)
throws IOException, RetryLaterException
{
long maxReceiveSize = FeatureFlags.maxAttachmentReceiveSizeBytes();
AttachmentTable database = SignalDatabase.attachments();
File attachmentFile = database.getOrCreateTransferFile(attachmentId);
try {
if (attachment.size > maxReceiveSize) {
throw new MmsException("Attachment too large, failing download");
}
SignalServiceMessageReceiver messageReceiver = ApplicationDependencies.getSignalServiceMessageReceiver();
SignalServiceAttachmentPointer pointer = createAttachmentPointer(attachment);
InputStream stream = messageReceiver.retrieveAttachment(pointer,
attachmentFile,
maxReceiveSize,
new SignalServiceAttachment.ProgressListener() {
@Override
public void onAttachmentProgress(long total, long progress) {
EventBus.getDefault().postSticky(new PartProgressEvent(attachment, PartProgressEvent.Type.NETWORK, total, progress));
}
@Override
public boolean shouldCancel() {
return isCanceled();
}
});
database.finalizeAttachmentAfterDownload(messageId, attachmentId, stream);
} catch (RangeException e) {
Log.w(TAG, "Range exception, file size " + attachmentFile.length(), e);
if (attachmentFile.delete()) {
Log.i(TAG, "Deleted temp download file to recover");
throw new RetryLaterException(e);
} else {
throw new IOException("Failed to delete temp download file following range exception");
}
} catch (InvalidPartException | NonSuccessfulResponseCodeException | MmsException | MissingConfigurationException e) {
Log.w(TAG, "Experienced exception while trying to download an attachment.", e);
markFailed(messageId, attachmentId);
} catch (InvalidMessageException e) {
Log.w(TAG, "Experienced an InvalidMessageException while trying to download an attachment.", e);
if (e.getCause() instanceof InvalidMacException) {
Log.w(TAG, "Detected an invalid mac. Treating as a permanent failure.");
markPermanentlyFailed(messageId, attachmentId);
} else {
markFailed(messageId, attachmentId);
}
}
}
private SignalServiceAttachmentPointer createAttachmentPointer(Attachment attachment) throws InvalidPartException {
if (TextUtils.isEmpty(attachment.remoteLocation)) {
throw new InvalidPartException("empty content id");
}
if (TextUtils.isEmpty(attachment.remoteKey)) {
throw new InvalidPartException("empty encrypted key");
}
try {
final SignalServiceAttachmentRemoteId remoteId = SignalServiceAttachmentRemoteId.from(attachment.remoteLocation);
final byte[] key = Base64.decode(attachment.remoteKey);
if (attachment.remoteDigest != null) {
Log.i(TAG, "Downloading attachment with digest: " + Hex.toString(attachment.remoteDigest));
} else {
Log.i(TAG, "Downloading attachment with no digest...");
}
return new SignalServiceAttachmentPointer(attachment.cdnNumber, remoteId, null, key,
Optional.of(Util.toIntExact(attachment.size)),
Optional.empty(),
0, 0,
Optional.ofNullable(attachment.remoteDigest),
Optional.ofNullable(attachment.getIncrementalDigest()),
attachment.incrementalMacChunkSize,
Optional.ofNullable(attachment.fileName),
attachment.voiceNote,
attachment.borderless,
attachment.videoGif,
Optional.empty(),
Optional.ofNullable(attachment.blurHash).map(BlurHash::getHash),
attachment.uploadTimestamp);
} catch (IOException | ArithmeticException e) {
Log.w(TAG, e);
throw new InvalidPartException(e);
}
}
private void retrieveAttachmentForReleaseChannel(long messageId,
final AttachmentId attachmentId,
final Attachment attachment)
throws IOException
{
try (Response response = S3.getObject(Objects.requireNonNull(attachment.fileName))) {
ResponseBody body = response.body();
if (body != null) {
if (body.contentLength() > FeatureFlags.maxAttachmentReceiveSizeBytes()) {
throw new MmsException("Attachment too large, failing download");
}
SignalDatabase.attachments().finalizeAttachmentAfterDownload(messageId, attachmentId, Okio.buffer(body.source()).inputStream());
}
} catch (MmsException e) {
Log.w(TAG, "Experienced exception while trying to download an attachment.", e);
markFailed(messageId, attachmentId);
}
}
private void markFailed(long messageId, AttachmentId attachmentId) {
try {
AttachmentTable database = SignalDatabase.attachments();
database.setTransferProgressFailed(attachmentId, messageId);
} catch (MmsException e) {
Log.w(TAG, e);
}
}
private void markPermanentlyFailed(long messageId, AttachmentId attachmentId) {
try {
AttachmentTable database = SignalDatabase.attachments();
database.setTransferProgressPermanentFailure(attachmentId, messageId);
} catch (MmsException e) {
Log.w(TAG, e);
}
}
public static boolean jobSpecMatchesAttachmentId(@NonNull JobSpec jobSpec, @NonNull AttachmentId attachmentId) {
if (!KEY.equals(jobSpec.getFactoryKey())) {
return false;
}
final byte[] serializedData = jobSpec.getSerializedData();
if (serializedData == null) {
return false;
}
JsonJobData data = JsonJobData.deserialize(serializedData);
final AttachmentId parsed = new AttachmentId(data.getLong(KEY_ATTACHMENT_ID));
return attachmentId.equals(parsed);
}
@VisibleForTesting
static class InvalidPartException extends Exception {
InvalidPartException(String s) {super(s);}
InvalidPartException(Exception e) {super(e);}
}
public static final class Factory implements Job.Factory<AttachmentDownloadJob> {
@Override
public @NonNull AttachmentDownloadJob create(@NonNull Parameters parameters, @Nullable byte[] serializedData) {
JsonJobData data = JsonJobData.deserialize(serializedData);
return new AttachmentDownloadJob(parameters,
data.getLong(KEY_MESSAGE_ID),
new AttachmentId(data.getLong(KEY_ATTACHMENT_ID)),
data.getBoolean(KEY_MANUAL));
}
}
}
@@ -0,0 +1,424 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.jobs
import android.text.TextUtils
import androidx.annotation.VisibleForTesting
import okio.Source
import okio.buffer
import org.greenrobot.eventbus.EventBus
import org.signal.core.util.Base64
import org.signal.core.util.Hex
import org.signal.core.util.logging.Log
import org.signal.libsignal.protocol.InvalidMacException
import org.signal.libsignal.protocol.InvalidMessageException
import org.thoughtcrime.securesms.attachments.Attachment
import org.thoughtcrime.securesms.attachments.AttachmentId
import org.thoughtcrime.securesms.attachments.Cdn
import org.thoughtcrime.securesms.attachments.DatabaseAttachment
import org.thoughtcrime.securesms.backup.v2.BackupRepository
import org.thoughtcrime.securesms.database.AttachmentTable
import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
import org.thoughtcrime.securesms.events.PartProgressEvent
import org.thoughtcrime.securesms.jobmanager.Job
import org.thoughtcrime.securesms.jobmanager.JobLogger.format
import org.thoughtcrime.securesms.jobmanager.JsonJobData
import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint
import org.thoughtcrime.securesms.jobmanager.persistence.JobSpec
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.mms.MmsException
import org.thoughtcrime.securesms.notifications.v2.ConversationId.Companion.forConversation
import org.thoughtcrime.securesms.s3.S3
import org.thoughtcrime.securesms.transport.RetryLaterException
import org.thoughtcrime.securesms.util.AttachmentUtil
import org.thoughtcrime.securesms.util.FeatureFlags
import org.thoughtcrime.securesms.util.Util
import org.whispersystems.signalservice.api.backup.MediaName
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId
import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException
import org.whispersystems.signalservice.api.push.exceptions.PushNetworkException
import org.whispersystems.signalservice.api.push.exceptions.RangeException
import java.io.File
import java.io.IOException
import java.util.Optional
import java.util.concurrent.TimeUnit
/**
* Download attachment from locations as specified in their record.
*/
class AttachmentDownloadJob private constructor(
parameters: Parameters,
private val messageId: Long,
attachmentId: AttachmentId,
private val manual: Boolean,
private var forceArchiveDownload: Boolean
) : BaseJob(parameters) {
companion object {
const val KEY = "AttachmentDownloadJob"
private val TAG = Log.tag(AttachmentDownloadJob::class.java)
private const val KEY_MESSAGE_ID = "message_id"
private const val KEY_ATTACHMENT_ID = "part_row_id"
private const val KEY_MANUAL = "part_manual"
private const val KEY_FORCE_ARCHIVE = "force_archive"
@JvmStatic
fun constructQueueString(attachmentId: AttachmentId): String {
return "AttachmentDownloadJob-" + attachmentId.id
}
fun jobSpecMatchesAttachmentId(jobSpec: JobSpec, attachmentId: AttachmentId): Boolean {
if (KEY != jobSpec.factoryKey) {
return false
}
val serializedData = jobSpec.serializedData ?: return false
val data = JsonJobData.deserialize(serializedData)
val parsed = AttachmentId(data.getLong(KEY_ATTACHMENT_ID))
return attachmentId == parsed
}
}
private val attachmentId: Long
constructor(messageId: Long, attachmentId: AttachmentId, manual: Boolean, forceArchiveDownload: Boolean = false) : this(
Parameters.Builder()
.setQueue(constructQueueString(attachmentId))
.addConstraint(NetworkConstraint.KEY)
.setLifespan(TimeUnit.DAYS.toMillis(1))
.setMaxAttempts(Parameters.UNLIMITED)
.build(),
messageId,
attachmentId,
manual,
forceArchiveDownload
)
init {
this.attachmentId = attachmentId.id
}
override fun serialize(): ByteArray? {
return JsonJobData.Builder()
.putLong(KEY_MESSAGE_ID, messageId)
.putLong(KEY_ATTACHMENT_ID, attachmentId)
.putBoolean(KEY_MANUAL, manual)
.putBoolean(KEY_FORCE_ARCHIVE, forceArchiveDownload)
.serialize()
}
override fun getFactoryKey(): String {
return KEY
}
override fun onAdded() {
Log.i(TAG, "onAdded() messageId: $messageId attachmentId: $attachmentId manual: $manual")
val attachmentId = AttachmentId(attachmentId)
val attachment = SignalDatabase.attachments.getAttachment(attachmentId)
val pending = attachment != null && attachment.transferState != AttachmentTable.TRANSFER_PROGRESS_DONE && attachment.transferState != AttachmentTable.TRANSFER_PROGRESS_PERMANENT_FAILURE
if (pending && (manual || AttachmentUtil.isAutoDownloadPermitted(context, attachment))) {
Log.i(TAG, "onAdded() Marking attachment progress as 'started'")
SignalDatabase.attachments.setTransferState(messageId, attachmentId, AttachmentTable.TRANSFER_PROGRESS_STARTED)
}
}
@Throws(Exception::class)
public override fun onRun() {
doWork()
if (!SignalDatabase.messages.isStory(messageId)) {
ApplicationDependencies.getMessageNotifier().updateNotification(context, forConversation(0))
}
}
@Throws(IOException::class, RetryLaterException::class)
fun doWork() {
Log.i(TAG, "onRun() messageId: $messageId attachmentId: $attachmentId manual: $manual")
val attachmentId = AttachmentId(attachmentId)
val attachment = SignalDatabase.attachments.getAttachment(attachmentId)
if (attachment == null) {
Log.w(TAG, "attachment no longer exists.")
return
}
if (attachment.isPermanentlyFailed) {
Log.w(TAG, "Attachment was marked as a permanent failure. Refusing to download.")
return
}
if (!attachment.isInProgress) {
Log.w(TAG, "Attachment was already downloaded.")
return
}
if (!manual && !AttachmentUtil.isAutoDownloadPermitted(context, attachment)) {
Log.w(TAG, "Attachment can't be auto downloaded...")
SignalDatabase.attachments.setTransferState(messageId, attachmentId, AttachmentTable.TRANSFER_PROGRESS_PENDING)
return
}
Log.i(TAG, "Downloading push part $attachmentId")
SignalDatabase.attachments.setTransferState(messageId, attachmentId, AttachmentTable.TRANSFER_PROGRESS_STARTED)
when (attachment.cdn) {
Cdn.S3 -> retrieveAttachmentForReleaseChannel(messageId, attachmentId, attachment)
else -> retrieveAttachment(messageId, attachmentId, attachment)
}
}
override fun onFailure() {
Log.w(TAG, format(this, "onFailure() messageId: $messageId attachmentId: $attachmentId manual: $manual"))
val attachmentId = AttachmentId(attachmentId)
markFailed(messageId, attachmentId)
}
override fun onShouldRetry(exception: Exception): Boolean {
return exception is PushNetworkException ||
exception is RetryLaterException
}
@Throws(IOException::class, RetryLaterException::class)
private fun retrieveAttachment(
messageId: Long,
attachmentId: AttachmentId,
attachment: DatabaseAttachment
) {
val maxReceiveSize: Long = FeatureFlags.maxAttachmentReceiveSizeBytes()
val attachmentFile: File = SignalDatabase.attachments.getOrCreateTransferFile(attachmentId)
var archiveFile: File? = null
var useArchiveCdn = false
try {
if (attachment.size > maxReceiveSize) {
throw MmsException("Attachment too large, failing download")
}
useArchiveCdn = if (SignalStore.backup().canReadWriteToArchiveCdn && (forceArchiveDownload || attachment.remoteLocation == null)) {
if (attachment.archiveMediaName.isNullOrEmpty()) {
throw InvalidPartException("Invalid attachment configuration")
}
true
} else {
false
}
val messageReceiver = ApplicationDependencies.getSignalServiceMessageReceiver()
val pointer = createAttachmentPointer(attachment, useArchiveCdn)
val progressListener = object : SignalServiceAttachment.ProgressListener {
override fun onAttachmentProgress(total: Long, progress: Long) {
EventBus.getDefault().postSticky(PartProgressEvent(attachment, PartProgressEvent.Type.NETWORK, total, progress))
}
override fun shouldCancel(): Boolean {
return this@AttachmentDownloadJob.isCanceled
}
}
val stream = if (useArchiveCdn) {
archiveFile = SignalDatabase.attachments.getOrCreateArchiveTransferFile(attachmentId)
val cdnCredentials = BackupRepository.getCdnReadCredentials().successOrThrow().headers
messageReceiver
.retrieveArchivedAttachment(
SignalStore.svr().getOrCreateMasterKey().deriveBackupKey().deriveMediaSecrets(MediaName(attachment.archiveMediaName!!)),
cdnCredentials,
archiveFile,
pointer,
attachmentFile,
maxReceiveSize,
progressListener
)
} else {
messageReceiver
.retrieveAttachment(
pointer,
attachmentFile,
maxReceiveSize,
progressListener
)
}
SignalDatabase.attachments.finalizeAttachmentAfterDownload(messageId, attachmentId, stream)
} catch (e: RangeException) {
val transferFile = archiveFile ?: attachmentFile
Log.w(TAG, "Range exception, file size " + transferFile.length(), e)
if (transferFile.delete()) {
Log.i(TAG, "Deleted temp download file to recover")
throw RetryLaterException(e)
} else {
throw IOException("Failed to delete temp download file following range exception")
}
} catch (e: InvalidPartException) {
Log.w(TAG, "Experienced exception while trying to download an attachment.", e)
markFailed(messageId, attachmentId)
} catch (e: NonSuccessfulResponseCodeException) {
if (SignalStore.backup().canReadWriteToArchiveCdn) {
if (e.code == 404 && !useArchiveCdn && attachment.archiveMediaName?.isNotEmpty() == true) {
Log.i(TAG, "Retrying download from archive CDN")
forceArchiveDownload = true
retrieveAttachment(messageId, attachmentId, attachment)
return
} else if (e.code == 401 && useArchiveCdn) {
SignalStore.backup().cdnReadCredentials = null
throw RetryLaterException(e)
}
}
Log.w(TAG, "Experienced exception while trying to download an attachment.", e)
markFailed(messageId, attachmentId)
} catch (e: MmsException) {
Log.w(TAG, "Experienced exception while trying to download an attachment.", e)
markFailed(messageId, attachmentId)
} catch (e: MissingConfigurationException) {
Log.w(TAG, "Experienced exception while trying to download an attachment.", e)
markFailed(messageId, attachmentId)
} catch (e: InvalidMessageException) {
Log.w(TAG, "Experienced an InvalidMessageException while trying to download an attachment.", e)
if (e.cause is InvalidMacException) {
Log.w(TAG, "Detected an invalid mac. Treating as a permanent failure.")
markPermanentlyFailed(messageId, attachmentId)
} else {
markFailed(messageId, attachmentId)
}
}
}
@Throws(InvalidPartException::class)
private fun createAttachmentPointer(attachment: DatabaseAttachment, useArchiveCdn: Boolean): SignalServiceAttachmentPointer {
if (TextUtils.isEmpty(attachment.remoteKey)) {
throw InvalidPartException("empty encrypted key")
}
return try {
val remoteData: RemoteData = if (useArchiveCdn) {
val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey()
val backupDirectories = BackupRepository.getCdnBackupDirectories().successOrThrow()
RemoteData(
remoteId = SignalServiceAttachmentRemoteId.Backup(
backupDir = backupDirectories.backupDir,
mediaDir = backupDirectories.mediaDir,
mediaId = backupKey.deriveMediaId(MediaName(attachment.archiveMediaName!!)).encode()
),
cdnNumber = attachment.archiveCdn
)
} else {
if (attachment.remoteLocation.isNullOrEmpty()) {
throw InvalidPartException("empty content id")
}
RemoteData(
remoteId = SignalServiceAttachmentRemoteId.from(attachment.remoteLocation),
cdnNumber = attachment.cdn.cdnNumber
)
}
val key = Base64.decode(attachment.remoteKey!!)
if (attachment.remoteDigest != null) {
Log.i(TAG, "Downloading attachment with digest: " + Hex.toString(attachment.remoteDigest))
} else {
Log.i(TAG, "Downloading attachment with no digest...")
}
SignalServiceAttachmentPointer(
remoteData.cdnNumber,
remoteData.remoteId,
null,
key,
Optional.of(Util.toIntExact(attachment.size)),
Optional.empty(),
0,
0,
Optional.ofNullable(attachment.remoteDigest),
Optional.ofNullable(attachment.getIncrementalDigest()),
attachment.incrementalMacChunkSize,
Optional.ofNullable(attachment.fileName),
attachment.voiceNote,
attachment.borderless,
attachment.videoGif,
Optional.empty(),
Optional.ofNullable(attachment.blurHash).map { it.hash },
attachment.uploadTimestamp
)
} catch (e: IOException) {
Log.w(TAG, e)
throw InvalidPartException(e)
} catch (e: ArithmeticException) {
Log.w(TAG, e)
throw InvalidPartException(e)
}
}
@Throws(IOException::class)
private fun retrieveAttachmentForReleaseChannel(
messageId: Long,
attachmentId: AttachmentId,
attachment: Attachment
) {
try {
S3.getObject(attachment.fileName!!).use { response ->
val body = response.body()
if (body != null) {
if (body.contentLength() > FeatureFlags.maxAttachmentReceiveSizeBytes()) {
throw MmsException("Attachment too large, failing download")
}
SignalDatabase.attachments.finalizeAttachmentAfterDownload(messageId, attachmentId, (body.source() as Source).buffer().inputStream())
}
}
} catch (e: MmsException) {
Log.w(TAG, "Experienced exception while trying to download an attachment.", e)
markFailed(messageId, attachmentId)
}
}
private fun markFailed(messageId: Long, attachmentId: AttachmentId) {
try {
SignalDatabase.attachments.setTransferProgressFailed(attachmentId, messageId)
} catch (e: MmsException) {
Log.w(TAG, e)
}
}
private fun markPermanentlyFailed(messageId: Long, attachmentId: AttachmentId) {
try {
SignalDatabase.attachments.setTransferProgressPermanentFailure(attachmentId, messageId)
} catch (e: MmsException) {
Log.w(TAG, e)
}
}
@VisibleForTesting
internal class InvalidPartException : Exception {
constructor(s: String?) : super(s)
constructor(e: Exception?) : super(e)
}
private data class RemoteData(val remoteId: SignalServiceAttachmentRemoteId, val cdnNumber: Int)
class Factory : Job.Factory<AttachmentDownloadJob?> {
override fun create(parameters: Parameters, serializedData: ByteArray?): AttachmentDownloadJob {
val data = JsonJobData.deserialize(serializedData)
return AttachmentDownloadJob(
parameters = parameters,
messageId = data.getLong(KEY_MESSAGE_ID),
attachmentId = AttachmentId(data.getLong(KEY_ATTACHMENT_ID)),
manual = data.getBoolean(KEY_MANUAL),
forceArchiveDownload = data.getBooleanOrDefault(KEY_FORCE_ARCHIVE, false)
)
}
}
}
@@ -85,7 +85,7 @@ public final class AvatarGroupsV1DownloadJob extends BaseJob {
attachment.deleteOnExit(); attachment.deleteOnExit();
SignalServiceMessageReceiver receiver = ApplicationDependencies.getSignalServiceMessageReceiver(); SignalServiceMessageReceiver receiver = ApplicationDependencies.getSignalServiceMessageReceiver();
SignalServiceAttachmentPointer pointer = new SignalServiceAttachmentPointer(0, new SignalServiceAttachmentRemoteId(avatarId), contentType, key, Optional.of(0), Optional.empty(), 0, 0, digest, Optional.empty(), 0, fileName, false, false, false, Optional.empty(), Optional.empty(), System.currentTimeMillis()); SignalServiceAttachmentPointer pointer = new SignalServiceAttachmentPointer(0, new SignalServiceAttachmentRemoteId.V2(avatarId), contentType, key, Optional.of(0), Optional.empty(), 0, 0, digest, Optional.empty(), 0, fileName, false, false, false, Optional.empty(), Optional.empty(), System.currentTimeMillis());
InputStream inputStream = receiver.retrieveAttachment(pointer, attachment, AvatarHelper.AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE); InputStream inputStream = receiver.retrieveAttachment(pointer, attachment, AvatarHelper.AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE);
AvatarHelper.setAvatar(context, record.get().getRecipientId(), inputStream); AvatarHelper.setAvatar(context, record.get().getRecipientId(), inputStream);
@@ -0,0 +1,113 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.jobs
import android.database.Cursor
import org.signal.core.util.logging.Log
import org.thoughtcrime.securesms.BuildConfig
import org.thoughtcrime.securesms.attachments.DatabaseAttachment
import org.thoughtcrime.securesms.backup.v2.BackupRepository
import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
import org.thoughtcrime.securesms.jobmanager.Job
import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.providers.BlobProvider
import org.whispersystems.signalservice.api.NetworkResult
import java.io.FileInputStream
import java.io.FileOutputStream
/**
* Job that is responsible for exporting the DB as a backup proto and
* also uploading the resulting proto.
*/
class BackupMessagesJob private constructor(parameters: Parameters) : BaseJob(parameters) {
companion object {
private val TAG = Log.tag(BackupMessagesJob::class.java)
const val KEY = "BackupMessagesJob"
}
constructor() : this(
Parameters.Builder()
.addConstraint(NetworkConstraint.KEY)
.setMaxAttempts(Parameters.UNLIMITED)
.setMaxInstancesForFactory(2)
.build()
)
override fun serialize(): ByteArray? = null
override fun getFactoryKey(): String = KEY
override fun onFailure() = Unit
private fun archiveAttachments() {
if (BuildConfig.MESSAGE_BACKUP_RESTORE_ENABLED) {
SignalStore.backup().canReadWriteToArchiveCdn = true
}
val batchSize = 100
SignalDatabase.attachments.getArchivableAttachments().use { cursor ->
while (!cursor.isAfterLast) {
val attachments = cursor.readAttachmentBatch(batchSize)
when (val archiveResult = BackupRepository.archiveMedia(attachments)) {
is NetworkResult.Success -> {
for (success in archiveResult.result.sourceNotFoundResponses) {
val attachmentId = archiveResult.result.mediaIdToAttachmentId(success.mediaId)
ApplicationDependencies
.getJobManager()
.startChain(AttachmentUploadJob(attachmentId))
.then(ArchiveAttachmentJob(attachmentId))
.enqueue()
}
}
else -> {
Log.e(TAG, "Failed to archive $archiveResult")
}
}
}
}
}
private fun Cursor.readAttachmentBatch(batchSize: Int): List<DatabaseAttachment> {
val attachments = ArrayList<DatabaseAttachment>()
for (i in 0 until batchSize) {
if (this.moveToNext()) {
attachments.addAll(SignalDatabase.attachments.getAttachments(this))
} else {
break
}
}
return attachments
}
override fun onRun() {
val tempBackupFile = BlobProvider.getInstance().forNonAutoEncryptingSingleSessionOnDisk(ApplicationDependencies.getApplication())
val outputStream = FileOutputStream(tempBackupFile)
BackupRepository.export(outputStream = outputStream, append = { tempBackupFile.appendBytes(it) }, plaintext = false)
FileInputStream(tempBackupFile).use {
BackupRepository.uploadBackupFile(it, tempBackupFile.length())
}
archiveAttachments()
if (!tempBackupFile.delete()) {
Log.e(TAG, "Failed to delete temp backup file")
}
}
override fun onShouldRetry(e: Exception): Boolean = false
class Factory : Job.Factory<BackupMessagesJob> {
override fun create(parameters: Parameters, serializedData: ByteArray?): BackupMessagesJob {
return BackupMessagesJob(parameters)
}
}
}
@@ -0,0 +1,105 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.jobs
import org.signal.core.util.logging.Log
import org.signal.libsignal.zkgroup.profiles.ProfileKey
import org.thoughtcrime.securesms.R
import org.thoughtcrime.securesms.backup.RestoreState
import org.thoughtcrime.securesms.backup.v2.BackupRepository
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
import org.thoughtcrime.securesms.jobmanager.Job
import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.net.NotPushRegisteredException
import org.thoughtcrime.securesms.providers.BlobProvider
import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.service.BackupProgressService
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment.ProgressListener
import java.io.IOException
/**
* Job that is responsible for restoring a backup from the server
*/
class BackupRestoreJob private constructor(parameters: Parameters) : BaseJob(parameters) {
companion object {
private val TAG = Log.tag(BackupRestoreJob::class.java)
const val KEY = "BackupRestoreJob"
}
constructor() : this(
Parameters.Builder()
.addConstraint(NetworkConstraint.KEY)
.setMaxAttempts(Parameters.UNLIMITED)
.setMaxInstancesForFactory(1)
.build()
)
override fun serialize(): ByteArray? = null
override fun getFactoryKey(): String = KEY
override fun onFailure() = Unit
override fun onAdded() {
SignalStore.backup().restoreState = RestoreState.PENDING
}
override fun onRun() {
if (!SignalStore.account().isRegistered) {
Log.e(TAG, "Not registered, cannot restore!")
throw NotPushRegisteredException()
}
BackupProgressService.start(context, context.getString(R.string.BackupProgressService_title)).use {
restore(it)
}
}
private fun restore(controller: BackupProgressService.Controller) {
SignalStore.backup().restoreState = RestoreState.RESTORING_DB
val progressListener = object : ProgressListener {
override fun onAttachmentProgress(total: Long, progress: Long) {
controller.update(
title = context.getString(R.string.BackupProgressService_title_downloading),
progress = progress.toFloat() / total.toFloat(),
indeterminate = false
)
}
override fun shouldCancel() = isCanceled
}
val tempBackupFile = BlobProvider.getInstance().forNonAutoEncryptingSingleSessionOnDisk(ApplicationDependencies.getApplication())
if (!BackupRepository.downloadBackupFile(tempBackupFile, progressListener)) {
Log.e(TAG, "Failed to download backup file")
throw IOException()
}
controller.update(
title = context.getString(R.string.BackupProgressService_title),
progress = 0f,
indeterminate = true
)
val self = Recipient.self()
val selfData = BackupRepository.SelfData(self.aci.get(), self.pni.get(), self.e164.get(), ProfileKey(self.profileKey))
BackupRepository.import(length = tempBackupFile.length(), inputStreamFactory = tempBackupFile::inputStream, selfData = selfData, plaintext = false)
SignalStore.backup().restoreState = RestoreState.RESTORING_MEDIA
}
override fun onShouldRetry(e: Exception): Boolean = false
class Factory : Job.Factory<BackupRestoreJob> {
override fun create(parameters: Parameters, serializedData: ByteArray?): BackupRestoreJob {
return BackupRestoreJob(parameters)
}
}
}
@@ -0,0 +1,83 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.jobs
import org.signal.core.util.logging.Log
import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.database.model.MmsMessageRecord
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
import org.thoughtcrime.securesms.jobmanager.Job
import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.net.NotPushRegisteredException
import kotlin.time.Duration.Companion.days
/**
* Job that is responsible for enqueueing attachment download
* jobs upon restore.
*/
class BackupRestoreMediaJob private constructor(parameters: Parameters) : BaseJob(parameters) {
companion object {
private val TAG = Log.tag(BackupRestoreMediaJob::class.java)
const val KEY = "BackupRestoreMediaJob"
}
constructor() : this(
Parameters.Builder()
.addConstraint(NetworkConstraint.KEY)
.setMaxAttempts(Parameters.UNLIMITED)
.setMaxInstancesForFactory(2)
.build()
)
override fun serialize(): ByteArray? = null
override fun getFactoryKey(): String = KEY
override fun onFailure() = Unit
override fun onRun() {
if (!SignalStore.account().isRegistered) {
Log.e(TAG, "Not registered, cannot restore!")
throw NotPushRegisteredException()
}
val jobManager = ApplicationDependencies.getJobManager()
val batchSize = 100
val restoreTime = System.currentTimeMillis()
var restoreJobBatch: List<RestoreAttachmentJob>
do {
val attachmentBatch = SignalDatabase.attachments.getRestorableAttachments(batchSize)
val messageIds = attachmentBatch.map { it.mmsId }.toSet()
val messageMap = SignalDatabase.messages.getMessages(messageIds).associate { it.id to (it as MmsMessageRecord) }
restoreJobBatch = SignalDatabase.attachments.getRestorableAttachments(batchSize).map { attachment ->
val message = messageMap[attachment.mmsId]!!
RestoreAttachmentJob(
messageId = attachment.mmsId,
attachmentId = attachment.attachmentId,
manual = false,
forceArchiveDownload = true,
fullSize = shouldRestoreFullSize(message, restoreTime, optimizeStorage = SignalStore.backup().optimizeStorage)
)
}
jobManager.addAll(restoreJobBatch)
} while (restoreJobBatch.isNotEmpty())
}
private fun shouldRestoreFullSize(message: MmsMessageRecord, restoreTime: Long, optimizeStorage: Boolean): Boolean {
return ((restoreTime - message.dateSent) < 30.days.inWholeMilliseconds) || !optimizeStorage
}
override fun onShouldRetry(e: Exception): Boolean = false
class Factory : Job.Factory<BackupRestoreMediaJob> {
override fun create(parameters: Parameters, serializedData: ByteArray?): BackupRestoreMediaJob {
return BackupRestoreMediaJob(parameters)
}
}
}
@@ -100,6 +100,7 @@ public final class JobManagerFactories {
return new HashMap<String, Job.Factory>() {{ return new HashMap<String, Job.Factory>() {{
put(AccountConsistencyWorkerJob.KEY, new AccountConsistencyWorkerJob.Factory()); put(AccountConsistencyWorkerJob.KEY, new AccountConsistencyWorkerJob.Factory());
put(AnalyzeDatabaseJob.KEY, new AnalyzeDatabaseJob.Factory()); put(AnalyzeDatabaseJob.KEY, new AnalyzeDatabaseJob.Factory());
put(ArchiveAttachmentJob.KEY, new ArchiveAttachmentJob.Factory());
put(AttachmentCompressionJob.KEY, new AttachmentCompressionJob.Factory()); put(AttachmentCompressionJob.KEY, new AttachmentCompressionJob.Factory());
put(AttachmentCopyJob.KEY, new AttachmentCopyJob.Factory()); put(AttachmentCopyJob.KEY, new AttachmentCopyJob.Factory());
put(AttachmentDownloadJob.KEY, new AttachmentDownloadJob.Factory()); put(AttachmentDownloadJob.KEY, new AttachmentDownloadJob.Factory());
@@ -109,6 +110,9 @@ public final class JobManagerFactories {
put(AutomaticSessionResetJob.KEY, new AutomaticSessionResetJob.Factory()); put(AutomaticSessionResetJob.KEY, new AutomaticSessionResetJob.Factory());
put(AvatarGroupsV1DownloadJob.KEY, new AvatarGroupsV1DownloadJob.Factory()); put(AvatarGroupsV1DownloadJob.KEY, new AvatarGroupsV1DownloadJob.Factory());
put(AvatarGroupsV2DownloadJob.KEY, new AvatarGroupsV2DownloadJob.Factory()); put(AvatarGroupsV2DownloadJob.KEY, new AvatarGroupsV2DownloadJob.Factory());
put(BackupMessagesJob.KEY, new BackupMessagesJob.Factory());
put(BackupRestoreJob.KEY, new BackupRestoreJob.Factory());
put(BackupRestoreMediaJob.KEY, new BackupRestoreMediaJob.Factory());
put(BoostReceiptRequestResponseJob.KEY, new BoostReceiptRequestResponseJob.Factory()); put(BoostReceiptRequestResponseJob.KEY, new BoostReceiptRequestResponseJob.Factory());
put(CallLinkPeekJob.KEY, new CallLinkPeekJob.Factory()); put(CallLinkPeekJob.KEY, new CallLinkPeekJob.Factory());
put(CallLinkUpdateSendJob.KEY, new CallLinkUpdateSendJob.Factory()); put(CallLinkUpdateSendJob.KEY, new CallLinkUpdateSendJob.Factory());
@@ -193,6 +197,7 @@ public final class JobManagerFactories {
put(ResumableUploadSpecJob.KEY, new ResumableUploadSpecJob.Factory()); put(ResumableUploadSpecJob.KEY, new ResumableUploadSpecJob.Factory());
put(RequestGroupV2InfoWorkerJob.KEY, new RequestGroupV2InfoWorkerJob.Factory()); put(RequestGroupV2InfoWorkerJob.KEY, new RequestGroupV2InfoWorkerJob.Factory());
put(RequestGroupV2InfoJob.KEY, new RequestGroupV2InfoJob.Factory()); put(RequestGroupV2InfoJob.KEY, new RequestGroupV2InfoJob.Factory());
put(RestoreAttachmentJob.KEY, new RestoreAttachmentJob.Factory());
put(RetrieveProfileAvatarJob.KEY, new RetrieveProfileAvatarJob.Factory()); put(RetrieveProfileAvatarJob.KEY, new RetrieveProfileAvatarJob.Factory());
put(RetrieveProfileJob.KEY, new RetrieveProfileJob.Factory()); put(RetrieveProfileJob.KEY, new RetrieveProfileJob.Factory());
put(RetrieveRemoteAnnouncementsJob.KEY, new RetrieveRemoteAnnouncementsJob.Factory()); put(RetrieveRemoteAnnouncementsJob.KEY, new RetrieveRemoteAnnouncementsJob.Factory());
@@ -288,7 +288,7 @@ public abstract class PushSendJob extends SendJob {
} }
} }
return new SignalServiceAttachmentPointer(attachment.cdnNumber, return new SignalServiceAttachmentPointer(attachment.cdn.getCdnNumber(),
remoteId, remoteId,
attachment.contentType, attachment.contentType,
key, key,
@@ -0,0 +1,400 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.jobs
import android.text.TextUtils
import androidx.annotation.VisibleForTesting
import org.greenrobot.eventbus.EventBus
import org.signal.core.util.Base64
import org.signal.core.util.Hex
import org.signal.core.util.logging.Log
import org.signal.libsignal.protocol.InvalidMacException
import org.signal.libsignal.protocol.InvalidMessageException
import org.thoughtcrime.securesms.attachments.AttachmentId
import org.thoughtcrime.securesms.attachments.DatabaseAttachment
import org.thoughtcrime.securesms.backup.v2.BackupRepository
import org.thoughtcrime.securesms.database.AttachmentTable
import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
import org.thoughtcrime.securesms.events.PartProgressEvent
import org.thoughtcrime.securesms.jobmanager.Job
import org.thoughtcrime.securesms.jobmanager.JobLogger.format
import org.thoughtcrime.securesms.jobmanager.JsonJobData
import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint
import org.thoughtcrime.securesms.jobmanager.persistence.JobSpec
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.mms.MmsException
import org.thoughtcrime.securesms.notifications.v2.ConversationId.Companion.forConversation
import org.thoughtcrime.securesms.transport.RetryLaterException
import org.thoughtcrime.securesms.util.FeatureFlags
import org.thoughtcrime.securesms.util.Util
import org.whispersystems.signalservice.api.backup.MediaName
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId
import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException
import org.whispersystems.signalservice.api.push.exceptions.PushNetworkException
import org.whispersystems.signalservice.api.push.exceptions.RangeException
import java.io.File
import java.io.IOException
import java.util.Optional
import java.util.concurrent.TimeUnit
/**
* Download attachment from locations as specified in their record.
*/
class RestoreAttachmentJob private constructor(
parameters: Parameters,
private val messageId: Long,
attachmentId: AttachmentId,
private val manual: Boolean,
private var forceArchiveDownload: Boolean,
private val fullSize: Boolean
) : BaseJob(parameters) {
companion object {
const val KEY = "RestoreAttachmentJob"
private val TAG = Log.tag(AttachmentDownloadJob::class.java)
private const val KEY_MESSAGE_ID = "message_id"
private const val KEY_ATTACHMENT_ID = "part_row_id"
private const val KEY_MANUAL = "part_manual"
private const val KEY_FORCE_ARCHIVE = "force_archive"
private const val KEY_FULL_SIZE = "full_size"
@JvmStatic
fun constructQueueString(attachmentId: AttachmentId): String {
// TODO: decide how many queues
return "RestoreAttachmentJob"
}
fun jobSpecMatchesAnyAttachmentId(jobSpec: JobSpec, ids: Set<AttachmentId>): Boolean {
if (KEY != jobSpec.factoryKey) {
return false
}
val serializedData = jobSpec.serializedData ?: return false
val data = JsonJobData.deserialize(serializedData)
val parsed = AttachmentId(data.getLong(KEY_ATTACHMENT_ID))
return ids.contains(parsed)
}
fun modifyPriorities(ids: Set<AttachmentId>, priority: Int) {
val jobManager = ApplicationDependencies.getJobManager()
jobManager.update { spec ->
if (jobSpecMatchesAnyAttachmentId(spec, ids) && spec.priority != priority) {
spec.copy(priority = priority)
} else {
spec
}
}
}
}
private val attachmentId: Long
constructor(messageId: Long, attachmentId: AttachmentId, manual: Boolean, forceArchiveDownload: Boolean = false, fullSize: Boolean = true) : this(
Parameters.Builder()
.setQueue(constructQueueString(attachmentId))
.addConstraint(NetworkConstraint.KEY)
.setLifespan(TimeUnit.DAYS.toMillis(1))
.setMaxAttempts(Parameters.UNLIMITED)
.build(),
messageId,
attachmentId,
manual,
forceArchiveDownload,
fullSize
)
init {
this.attachmentId = attachmentId.id
}
override fun serialize(): ByteArray? {
return JsonJobData.Builder()
.putLong(KEY_MESSAGE_ID, messageId)
.putLong(KEY_ATTACHMENT_ID, attachmentId)
.putBoolean(KEY_MANUAL, manual)
.putBoolean(KEY_FORCE_ARCHIVE, forceArchiveDownload)
.putBoolean(KEY_FULL_SIZE, fullSize)
.serialize()
}
override fun getFactoryKey(): String {
return KEY
}
override fun onAdded() {
Log.i(TAG, "onAdded() messageId: $messageId attachmentId: $attachmentId manual: $manual")
val attachmentId = AttachmentId(attachmentId)
val attachment = SignalDatabase.attachments.getAttachment(attachmentId)
val pending = attachment != null && attachment.transferState != AttachmentTable.TRANSFER_PROGRESS_DONE && attachment.transferState != AttachmentTable.TRANSFER_PROGRESS_PERMANENT_FAILURE
if (attachment?.transferState == AttachmentTable.TRANSFER_NEEDS_RESTORE) {
Log.i(TAG, "onAdded() Marking attachment restore progress as 'started'")
SignalDatabase.attachments.setTransferState(messageId, attachmentId, AttachmentTable.TRANSFER_RESTORE_IN_PROGRESS)
}
}
@Throws(Exception::class)
public override fun onRun() {
doWork()
if (!SignalDatabase.messages.isStory(messageId)) {
ApplicationDependencies.getMessageNotifier().updateNotification(context, forConversation(0))
}
}
@Throws(IOException::class, RetryLaterException::class)
fun doWork() {
Log.i(TAG, "onRun() messageId: $messageId attachmentId: $attachmentId manual: $manual")
val attachmentId = AttachmentId(attachmentId)
val attachment = SignalDatabase.attachments.getAttachment(attachmentId)
if (attachment == null) {
Log.w(TAG, "attachment no longer exists.")
return
}
if (attachment.isPermanentlyFailed) {
Log.w(TAG, "Attachment was marked as a permanent failure. Refusing to download.")
return
}
if (attachment.transferState != AttachmentTable.TRANSFER_NEEDS_RESTORE && attachment.transferState != AttachmentTable.TRANSFER_RESTORE_IN_PROGRESS) {
Log.w(TAG, "Attachment does not need to be restored.")
return
}
retrieveAttachment(messageId, attachmentId, attachment)
}
override fun onFailure() {
Log.w(TAG, format(this, "onFailure() messageId: $messageId attachmentId: $attachmentId manual: $manual"))
val attachmentId = AttachmentId(attachmentId)
markFailed(messageId, attachmentId)
}
override fun onShouldRetry(exception: Exception): Boolean {
return exception is PushNetworkException ||
exception is RetryLaterException
}
@Throws(IOException::class, RetryLaterException::class)
private fun retrieveAttachment(
messageId: Long,
attachmentId: AttachmentId,
attachment: DatabaseAttachment
) {
val maxReceiveSize: Long = FeatureFlags.maxAttachmentReceiveSizeBytes()
val attachmentFile: File = SignalDatabase.attachments.getOrCreateTransferFile(attachmentId)
var archiveFile: File? = null
var useArchiveCdn = false
try {
if (attachment.size > maxReceiveSize) {
throw MmsException("Attachment too large, failing download")
}
useArchiveCdn = if (SignalStore.backup().canReadWriteToArchiveCdn && (forceArchiveDownload || attachment.remoteLocation == null)) {
if (attachment.archiveMediaName.isNullOrEmpty()) {
throw InvalidPartException("Invalid attachment configuration")
}
true
} else {
false
}
val messageReceiver = ApplicationDependencies.getSignalServiceMessageReceiver()
val pointer = createAttachmentPointer(attachment, useArchiveCdn)
val progressListener = object : SignalServiceAttachment.ProgressListener {
override fun onAttachmentProgress(total: Long, progress: Long) {
EventBus.getDefault().postSticky(PartProgressEvent(attachment, PartProgressEvent.Type.NETWORK, total, progress))
}
override fun shouldCancel(): Boolean {
return this@RestoreAttachmentJob.isCanceled
}
}
val stream = if (useArchiveCdn) {
archiveFile = SignalDatabase.attachments.getOrCreateArchiveTransferFile(attachmentId)
val cdnCredentials = BackupRepository.getCdnReadCredentials().successOrThrow().headers
messageReceiver
.retrieveArchivedAttachment(
SignalStore.svr().getOrCreateMasterKey().deriveBackupKey().deriveMediaSecrets(MediaName(attachment.archiveMediaName!!)),
cdnCredentials,
archiveFile,
pointer,
attachmentFile,
maxReceiveSize,
progressListener
)
} else {
messageReceiver
.retrieveAttachment(
pointer,
attachmentFile,
maxReceiveSize,
progressListener
)
}
SignalDatabase.attachments.finalizeAttachmentAfterDownload(messageId, attachmentId, stream)
} catch (e: RangeException) {
val transferFile = archiveFile ?: attachmentFile
Log.w(TAG, "Range exception, file size " + transferFile.length(), e)
if (transferFile.delete()) {
Log.i(TAG, "Deleted temp download file to recover")
throw RetryLaterException(e)
} else {
throw IOException("Failed to delete temp download file following range exception")
}
} catch (e: InvalidPartException) {
Log.w(TAG, "Experienced exception while trying to download an attachment.", e)
markFailed(messageId, attachmentId)
} catch (e: NonSuccessfulResponseCodeException) {
if (SignalStore.backup().canReadWriteToArchiveCdn) {
if (e.code == 404 && !useArchiveCdn && attachment.archiveMediaName?.isNotEmpty() == true) {
Log.i(TAG, "Retrying download from archive CDN")
forceArchiveDownload = true
retrieveAttachment(messageId, attachmentId, attachment)
return
} else if (e.code == 401 && useArchiveCdn) {
SignalStore.backup().cdnReadCredentials = null
throw RetryLaterException(e)
}
}
Log.w(TAG, "Experienced exception while trying to download an attachment.", e)
markFailed(messageId, attachmentId)
} catch (e: MmsException) {
Log.w(TAG, "Experienced exception while trying to download an attachment.", e)
markFailed(messageId, attachmentId)
} catch (e: MissingConfigurationException) {
Log.w(TAG, "Experienced exception while trying to download an attachment.", e)
markFailed(messageId, attachmentId)
} catch (e: InvalidMessageException) {
Log.w(TAG, "Experienced an InvalidMessageException while trying to download an attachment.", e)
if (e.cause is InvalidMacException) {
Log.w(TAG, "Detected an invalid mac. Treating as a permanent failure.")
markPermanentlyFailed(messageId, attachmentId)
} else {
markFailed(messageId, attachmentId)
}
}
}
@Throws(InvalidPartException::class)
private fun createAttachmentPointer(attachment: DatabaseAttachment, useArchiveCdn: Boolean): SignalServiceAttachmentPointer {
if (TextUtils.isEmpty(attachment.remoteKey)) {
throw InvalidPartException("empty encrypted key")
}
return try {
val remoteData: RemoteData = if (useArchiveCdn) {
val backupKey = SignalStore.svr().getOrCreateMasterKey().deriveBackupKey()
val backupDirectories = BackupRepository.getCdnBackupDirectories().successOrThrow()
RemoteData(
remoteId = SignalServiceAttachmentRemoteId.Backup(
backupDir = backupDirectories.backupDir,
mediaDir = backupDirectories.mediaDir,
mediaId = backupKey.deriveMediaId(MediaName(attachment.archiveMediaName!!)).encode()
),
cdnNumber = attachment.archiveCdn
)
} else {
if (attachment.remoteLocation.isNullOrEmpty()) {
throw InvalidPartException("empty content id")
}
RemoteData(
remoteId = SignalServiceAttachmentRemoteId.from(attachment.remoteLocation),
cdnNumber = attachment.cdn.cdnNumber
)
}
val key = Base64.decode(attachment.remoteKey!!)
if (attachment.remoteDigest != null) {
Log.i(TAG, "Downloading attachment with digest: " + Hex.toString(attachment.remoteDigest))
} else {
Log.i(TAG, "Downloading attachment with no digest...")
}
SignalServiceAttachmentPointer(
remoteData.cdnNumber,
remoteData.remoteId,
null,
key,
Optional.of(Util.toIntExact(attachment.size)),
Optional.empty(),
0,
0,
Optional.ofNullable(attachment.remoteDigest),
Optional.ofNullable(attachment.getIncrementalDigest()),
attachment.incrementalMacChunkSize,
Optional.ofNullable(attachment.fileName),
attachment.voiceNote,
attachment.borderless,
attachment.videoGif,
Optional.empty(),
Optional.ofNullable(attachment.blurHash).map { it.hash },
attachment.uploadTimestamp
)
} catch (e: IOException) {
Log.w(TAG, e)
throw InvalidPartException(e)
} catch (e: ArithmeticException) {
Log.w(TAG, e)
throw InvalidPartException(e)
}
}
private fun markFailed(messageId: Long, attachmentId: AttachmentId) {
try {
SignalDatabase.attachments.setTransferProgressFailed(attachmentId, messageId)
} catch (e: MmsException) {
Log.w(TAG, e)
}
}
private fun markPermanentlyFailed(messageId: Long, attachmentId: AttachmentId) {
try {
SignalDatabase.attachments.setTransferProgressPermanentFailure(attachmentId, messageId)
} catch (e: MmsException) {
Log.w(TAG, e)
}
}
@VisibleForTesting
internal class InvalidPartException : Exception {
constructor(s: String?) : super(s)
constructor(e: Exception?) : super(e)
}
private data class RemoteData(val remoteId: SignalServiceAttachmentRemoteId, val cdnNumber: Int)
class Factory : Job.Factory<RestoreAttachmentJob?> {
override fun create(parameters: Parameters, serializedData: ByteArray?): RestoreAttachmentJob {
val data = JsonJobData.deserialize(serializedData)
return RestoreAttachmentJob(
parameters = parameters,
messageId = data.getLong(KEY_MESSAGE_ID),
attachmentId = AttachmentId(data.getLong(KEY_ATTACHMENT_ID)),
manual = data.getBoolean(KEY_MANUAL),
forceArchiveDownload = data.getBooleanOrDefault(KEY_FORCE_ARCHIVE, false),
fullSize = data.getBooleanOrDefault(KEY_FULL_SIZE, true)
)
}
}
}
@@ -2,21 +2,44 @@ package org.thoughtcrime.securesms.keyvalue
import com.fasterxml.jackson.annotation.JsonProperty import com.fasterxml.jackson.annotation.JsonProperty
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import org.thoughtcrime.securesms.backup.RestoreState
import org.whispersystems.signalservice.api.archive.ArchiveServiceCredential import org.whispersystems.signalservice.api.archive.ArchiveServiceCredential
import org.whispersystems.signalservice.api.archive.GetArchiveCdnCredentialsResponse
import org.whispersystems.signalservice.internal.util.JsonUtil import org.whispersystems.signalservice.internal.util.JsonUtil
import java.io.IOException import java.io.IOException
import kotlin.time.Duration import kotlin.time.Duration
import kotlin.time.Duration.Companion.days import kotlin.time.Duration.Companion.days
import kotlin.time.Duration.Companion.hours
internal class BackupValues(store: KeyValueStore) : SignalStoreValues(store) { internal class BackupValues(store: KeyValueStore) : SignalStoreValues(store) {
companion object { companion object {
val TAG = Log.tag(BackupValues::class.java) val TAG = Log.tag(BackupValues::class.java)
val KEY_CREDENTIALS = "backup.credentials" private const val KEY_CREDENTIALS = "backup.credentials"
private const val KEY_CDN_CAN_READ_WRITE = "backup.cdn.canReadWrite"
private const val KEY_CDN_READ_CREDENTIALS = "backup.cdn.readCredentials"
private const val KEY_CDN_READ_CREDENTIALS_TIMESTAMP = "backup.cdn.readCredentials.timestamp"
private const val KEY_RESTORE_STATE = "backup.restoreState"
private const val KEY_CDN_BACKUP_DIRECTORY = "backup.cdn.directory"
private const val KEY_CDN_BACKUP_MEDIA_DIRECTORY = "backup.cdn.mediaDirectory"
private const val KEY_OPTIMIZE_STORAGE = "backup.optimizeStorage"
private val cachedCdnCredentialsExpiresIn: Duration = 12.hours
} }
private var cachedCdnCredentialsTimestamp: Long by longValue(KEY_CDN_READ_CREDENTIALS_TIMESTAMP, 0L)
private var cachedCdnCredentials: String? by stringValue(KEY_CDN_READ_CREDENTIALS, null)
var cachedBackupDirectory: String? by stringValue(KEY_CDN_BACKUP_DIRECTORY, null)
var cachedBackupMediaDirectory: String? by stringValue(KEY_CDN_BACKUP_MEDIA_DIRECTORY, null)
override fun onFirstEverAppLaunch() = Unit override fun onFirstEverAppLaunch() = Unit
override fun getKeysToIncludeInBackup(): List<String> = emptyList() override fun getKeysToIncludeInBackup(): List<String> = emptyList()
var canReadWriteToArchiveCdn: Boolean by booleanValue(KEY_CDN_CAN_READ_WRITE, false)
var restoreState: RestoreState by enumValue(KEY_RESTORE_STATE, RestoreState.NONE, RestoreState.serializer)
var optimizeStorage: Boolean by booleanValue(KEY_OPTIMIZE_STORAGE, false)
/** /**
* Retrieves the stored credentials, mapped by the day they're valid. The day is represented as * Retrieves the stored credentials, mapped by the day they're valid. The day is represented as
* the unix time (in seconds) of the start of the day. Wrapped in a [ArchiveServiceCredentials] * the unix time (in seconds) of the start of the day. Wrapped in a [ArchiveServiceCredentials]
@@ -36,6 +59,28 @@ internal class BackupValues(store: KeyValueStore) : SignalStoreValues(store) {
} }
} }
var cdnReadCredentials: GetArchiveCdnCredentialsResponse?
get() {
val cacheAge = System.currentTimeMillis() - cachedCdnCredentialsTimestamp
val cached = cachedCdnCredentials
return if (cached != null && (cacheAge > 0 && cacheAge < cachedCdnCredentialsExpiresIn.inWholeMilliseconds)) {
try {
JsonUtil.fromJson(cached, GetArchiveCdnCredentialsResponse::class.java)
} catch (e: IOException) {
Log.w(TAG, "Invalid JSON! Clearing.", e)
cachedCdnCredentials = null
null
}
} else {
null
}
}
set(value) {
cachedCdnCredentials = value?.let { JsonUtil.toJson(it) }
cachedCdnCredentialsTimestamp = System.currentTimeMillis()
}
/** /**
* Adds the given credentials to the existing list of stored credentials. * Adds the given credentials to the existing list of stored credentials.
*/ */
@@ -255,7 +255,7 @@ public class LegacyMigrationJob extends MigrationJob {
attachmentDb.setTransferState(attachment.mmsId, attachment.attachmentId, AttachmentTable.TRANSFER_PROGRESS_DONE); attachmentDb.setTransferState(attachment.mmsId, attachment.attachmentId, AttachmentTable.TRANSFER_PROGRESS_DONE);
} else if (record != null && !record.isOutgoing() && record.isPush()) { } else if (record != null && !record.isOutgoing() && record.isPush()) {
Log.i(TAG, "queuing new attachment download job for incoming push part " + attachment.attachmentId + "."); Log.i(TAG, "queuing new attachment download job for incoming push part " + attachment.attachmentId + ".");
ApplicationDependencies.getJobManager().add(new AttachmentDownloadJob(attachment.mmsId, attachment.attachmentId, false)); ApplicationDependencies.getJobManager().add(new AttachmentDownloadJob(attachment.mmsId, attachment.attachmentId, false, false));
} }
reader.close(); reader.close();
} }
@@ -8,6 +8,7 @@ public final class NotificationIds {
public static final int FCM_FAILURE = 12; public static final int FCM_FAILURE = 12;
public static final int ATTACHMENT_PROGRESS = 50; public static final int ATTACHMENT_PROGRESS = 50;
public static final int BACKUP_PROGRESS = 51;
public static final int APK_UPDATE_PROMPT_INSTALL = 666; public static final int APK_UPDATE_PROMPT_INSTALL = 666;
public static final int APK_UPDATE_FAILED_INSTALL = 667; public static final int APK_UPDATE_FAILED_INSTALL = 667;
public static final int APK_UPDATE_SUCCESSFUL_INSTALL = 668; public static final int APK_UPDATE_SUCCESSFUL_INSTALL = 668;
@@ -1,5 +1,6 @@
package org.thoughtcrime.securesms.releasechannel package org.thoughtcrime.securesms.releasechannel
import org.thoughtcrime.securesms.attachments.Cdn
import org.thoughtcrime.securesms.attachments.PointerAttachment import org.thoughtcrime.securesms.attachments.PointerAttachment
import org.thoughtcrime.securesms.database.MessageTable import org.thoughtcrime.securesms.database.MessageTable
import org.thoughtcrime.securesms.database.MessageType import org.thoughtcrime.securesms.database.MessageType
@@ -20,8 +21,6 @@ import java.util.UUID
*/ */
object ReleaseChannel { object ReleaseChannel {
const val CDN_NUMBER = -1
fun insertReleaseChannelMessage( fun insertReleaseChannelMessage(
recipientId: RecipientId, recipientId: RecipientId,
body: String, body: String,
@@ -36,8 +35,8 @@ object ReleaseChannel {
): MessageTable.InsertResult? { ): MessageTable.InsertResult? {
val attachments: Optional<List<SignalServiceAttachment>> = if (media != null) { val attachments: Optional<List<SignalServiceAttachment>> = if (media != null) {
val attachment = SignalServiceAttachmentPointer( val attachment = SignalServiceAttachmentPointer(
CDN_NUMBER, Cdn.S3.cdnNumber,
SignalServiceAttachmentRemoteId.from(""), SignalServiceAttachmentRemoteId.S3,
mediaType, mediaType,
null, null,
Optional.empty(), Optional.empty(),
@@ -150,8 +150,8 @@ class AttachmentProgressService : SafeForegroundService() {
/** Has to have separate setter to avoid infinite loops when [progress] and [indeterminate] interact. */ /** Has to have separate setter to avoid infinite loops when [progress] and [indeterminate] interact. */
fun setIndeterminate(value: Boolean) { fun setIndeterminate(value: Boolean) {
indeterminate = value
progress = 0f progress = 0f
indeterminate = value
onControllersChanged(context) onControllersChanged(context)
} }
@@ -0,0 +1,120 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.service
import android.annotation.SuppressLint
import android.app.Notification
import android.app.PendingIntent
import android.content.Context
import android.content.Intent
import androidx.core.app.NotificationCompat
import androidx.core.app.NotificationManagerCompat
import org.signal.core.util.PendingIntentFlags
import org.signal.core.util.logging.Log
import org.thoughtcrime.securesms.MainActivity
import org.thoughtcrime.securesms.R
import org.thoughtcrime.securesms.notifications.NotificationChannels
import org.thoughtcrime.securesms.notifications.NotificationIds
import java.util.concurrent.locks.ReentrantLock
import javax.annotation.CheckReturnValue
import kotlin.concurrent.withLock
/**
* Foreground service to provide "long" run support to backup jobs.
*/
class BackupProgressService : SafeForegroundService() {
companion object {
private val TAG = Log.tag(BackupProgressService::class)
@SuppressLint("StaticFieldLeak")
private var controller: Controller? = null
private val controllerLock = ReentrantLock()
private var title: String = ""
private var progress: Float = 0f
private var indeterminate: Boolean = true
@CheckReturnValue
fun start(context: Context, startingTitle: String): Controller {
controllerLock.withLock {
if (controller != null) {
Log.w(TAG, "Starting service with existing controller")
}
controller = Controller(context, startingTitle)
val started = SafeForegroundService.start(context, BackupProgressService::class.java)
if (started) {
Log.i(TAG, "[start] Starting")
} else {
Log.w(TAG, "[start] Service already started")
}
return controller!!
}
}
private fun stop(context: Context) {
SafeForegroundService.stop(context, BackupProgressService::class.java)
controllerLock.withLock {
controller = null
}
}
private fun getForegroundNotification(context: Context): Notification {
return NotificationCompat.Builder(context, NotificationChannels.getInstance().OTHER)
.setSmallIcon(R.drawable.ic_notification)
.setContentTitle(title)
.setProgress(100, (progress * 100).toInt(), indeterminate)
.setContentIntent(PendingIntent.getActivity(context, 0, MainActivity.clearTop(context), PendingIntentFlags.mutable()))
.setVibrate(longArrayOf(0))
.build()
}
}
override val tag: String = TAG
override val notificationId: Int = NotificationIds.BACKUP_PROGRESS
override fun getForegroundNotification(intent: Intent): Notification {
return getForegroundNotification(this)
}
/**
* Use to update notification progress/state.
*/
class Controller(private val context: Context, startingTitle: String) : AutoCloseable {
init {
title = startingTitle
progress = 0f
indeterminate = true
}
fun update(title: String, progress: Float, indeterminate: Boolean) {
controllerLock.withLock {
if (this != controller) {
return
}
BackupProgressService.title = title
BackupProgressService.progress = progress
BackupProgressService.indeterminate = indeterminate
if (NotificationManagerCompat.from(context).activeNotifications.any { n -> n.id == NotificationIds.BACKUP_PROGRESS }) {
NotificationManagerCompat.from(context).notify(NotificationIds.BACKUP_PROGRESS, getForegroundNotification(context))
}
}
}
override fun close() {
controllerLock.withLock {
if (this == controller) {
stop(context)
}
}
}
}
}
@@ -141,7 +141,7 @@ object Stories {
if (record.hasLinkPreview() && record.linkPreviews[0].attachmentId != null) { if (record.hasLinkPreview() && record.linkPreviews[0].attachmentId != null) {
ApplicationDependencies.getJobManager().add( ApplicationDependencies.getJobManager().add(
AttachmentDownloadJob(record.id, record.linkPreviews[0].attachmentId, true) AttachmentDownloadJob(record.id, record.linkPreviews[0].attachmentId!!, true)
) )
} }
} }
+22 -22
View File
@@ -381,6 +381,7 @@ message MessageAttachment {
FilePointer pointer = 1; FilePointer pointer = 1;
Flag flag = 2; Flag flag = 2;
bool wasDownloaded = 3;
} }
message FilePointer { message FilePointer {
@@ -388,6 +389,9 @@ message FilePointer {
message BackupLocator { message BackupLocator {
string mediaName = 1; string mediaName = 1;
uint32 cdnNumber = 2; uint32 cdnNumber = 2;
bytes key = 3;
bytes digest = 4;
uint32 size = 5;
} }
// References attachments in the transit storage tier. // References attachments in the transit storage tier.
@@ -398,37 +402,33 @@ message FilePointer {
string cdnKey = 1; string cdnKey = 1;
uint32 cdnNumber = 2; uint32 cdnNumber = 2;
uint64 uploadTimestamp = 3; uint64 uploadTimestamp = 3;
bytes key = 4;
bytes digest = 5;
uint32 size = 6;
} }
// An attachment that was copied from the transit storage tier // References attachments that are invalid in such a way where download
// to the backup (media) storage tier up without being downloaded. // cannot be attempted. Could range from missing digests to missing
// Its MediaName should be generated as {sender_aci}_{cdn_attachment_key}, // CDN keys or anything else that makes download attempts impossible.
// but should eventually transition to a BackupLocator with mediaName // This serves as a 'tombstone' so that the UX can show that an attachment
// being the content hash once it is downloaded. // did exist, but for whatever reason it's not retrievable.
message UndownloadedBackupLocator { message InvalidAttachmentLocator {
bytes senderAci = 1;
string cdnKey = 2;
uint32 cdnNumber = 3;
} }
oneof locator { oneof locator {
BackupLocator backupLocator = 1; BackupLocator backupLocator = 1;
AttachmentLocator attachmentLocator= 2; AttachmentLocator attachmentLocator= 2;
UndownloadedBackupLocator undownloadedBackupLocator = 3; InvalidAttachmentLocator invalidAttachmentLocator = 3;
} }
optional bytes key = 5; optional string contentType = 4;
optional string contentType = 6; optional bytes incrementalMac = 5;
// Size of fullsize decrypted media blob in bytes. optional uint32 incrementalMacChunkSize = 6;
// Can be ignored if unset/unavailable. optional string fileName = 7;
optional uint32 size = 7; optional uint32 width = 8;
optional bytes incrementalMac = 8; optional uint32 height = 9;
optional uint32 incrementalMacChunkSize = 9; optional string caption = 10;
optional string fileName = 10; optional string blurHash = 11;
optional uint32 width = 11;
optional uint32 height = 12;
optional string caption = 13;
optional string blurHash = 14;
} }
message Quote { message Quote {
+4
View File
@@ -47,3 +47,7 @@ message AttachmentUploadJobData {
message PreKeysSyncJobData { message PreKeysSyncJobData {
bool forceRefreshRequested = 1; bool forceRefreshRequested = 1;
} }
message ArchiveAttachmentJobData {
uint64 attachmentId = 1;
}
+6
View File
@@ -6707,5 +6707,11 @@
<!-- Content description for opening the note editor --> <!-- Content description for opening the note editor -->
<string name="ViewNoteSheet__edit_note">Edit note</string> <string name="ViewNoteSheet__edit_note">Edit note</string>
<!-- BackupProgressService -->
<!-- Notification title shown while backup restore job is running -->
<string name="BackupProgressService_title">Restoring backup…</string>
<!-- Notification title shown while downloading backup restore data -->
<string name="BackupProgressService_title_downloading">Downloading backup data…</string>
<!-- EOF --> <!-- EOF -->
</resources> </resources>
@@ -236,7 +236,7 @@ class UploadDependencyGraphTest {
transferProgress = AttachmentTable.TRANSFER_PROGRESS_PENDING, transferProgress = AttachmentTable.TRANSFER_PROGRESS_PENDING,
size = attachment.size, size = attachment.size,
fileName = attachment.fileName, fileName = attachment.fileName,
cdnNumber = attachment.cdnNumber, cdn = attachment.cdn,
location = attachment.remoteLocation, location = attachment.remoteLocation,
key = attachment.remoteKey, key = attachment.remoteKey,
digest = attachment.remoteDigest, digest = attachment.remoteDigest,
@@ -256,7 +256,10 @@ class UploadDependencyGraphTest {
transformProperties = attachment.transformProperties, transformProperties = attachment.transformProperties,
displayOrder = 0, displayOrder = 0,
uploadTimestamp = attachment.uploadTimestamp, uploadTimestamp = attachment.uploadTimestamp,
dataHash = null dataHash = null,
archiveMediaId = null,
archiveMediaName = null,
archiveCdn = 0
) )
} }
@@ -1,6 +1,7 @@
package org.thoughtcrime.securesms.database package org.thoughtcrime.securesms.database
import org.thoughtcrime.securesms.attachments.AttachmentId import org.thoughtcrime.securesms.attachments.AttachmentId
import org.thoughtcrime.securesms.attachments.Cdn
import org.thoughtcrime.securesms.attachments.DatabaseAttachment import org.thoughtcrime.securesms.attachments.DatabaseAttachment
import org.thoughtcrime.securesms.audio.AudioHash import org.thoughtcrime.securesms.audio.AudioHash
import org.thoughtcrime.securesms.blurhash.BlurHash import org.thoughtcrime.securesms.blurhash.BlurHash
@@ -35,7 +36,7 @@ object FakeMessageRecords {
transferProgress: Int = AttachmentTable.TRANSFER_PROGRESS_DONE, transferProgress: Int = AttachmentTable.TRANSFER_PROGRESS_DONE,
size: Long = 0L, size: Long = 0L,
fileName: String = "", fileName: String = "",
cdnNumber: Int = 1, cdnNumber: Int = 3,
location: String = "", location: String = "",
key: String = "", key: String = "",
relay: String = "", relay: String = "",
@@ -56,7 +57,10 @@ object FakeMessageRecords {
transformProperties: AttachmentTable.TransformProperties? = null, transformProperties: AttachmentTable.TransformProperties? = null,
displayOrder: Int = 0, displayOrder: Int = 0,
uploadTimestamp: Long = 200, uploadTimestamp: Long = 200,
dataHash: String? = null dataHash: String? = null,
archiveCdn: Int = 0,
archiveMediaName: String? = null,
archiveMediaId: String? = null
): DatabaseAttachment { ): DatabaseAttachment {
return DatabaseAttachment( return DatabaseAttachment(
attachmentId, attachmentId,
@@ -67,7 +71,7 @@ object FakeMessageRecords {
transferProgress, transferProgress,
size, size,
fileName, fileName,
cdnNumber, Cdn.fromCdnNumber(cdnNumber),
location, location,
key, key,
digest, digest,
@@ -87,7 +91,10 @@ object FakeMessageRecords {
transformProperties, transformProperties,
displayOrder, displayOrder,
uploadTimestamp, uploadTimestamp,
dataHash dataHash,
archiveCdn,
archiveMediaId,
archiveMediaName
) )
} }
@@ -5,6 +5,8 @@
package org.signal.core.util.logging package org.signal.core.util.logging
import kotlin.reflect.KClass
object Log { object Log {
private val NOOP_LOGGER: Logger = NoopLogger() private val NOOP_LOGGER: Logger = NoopLogger()
private var internalCheck: InternalCheck? = null private var internalCheck: InternalCheck? = null
@@ -102,6 +104,11 @@ object Log {
@JvmStatic @JvmStatic
fun e(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = logger.e(tag, message, t, keepLonger) fun e(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = logger.e(tag, message, t, keepLonger)
@JvmStatic
fun tag(clazz: KClass<*>): String {
return tag(clazz.java)
}
@JvmStatic @JvmStatic
fun tag(clazz: Class<*>): String { fun tag(clazz: Class<*>): String {
val simpleName = clazz.simpleName val simpleName = clazz.simpleName
@@ -6,13 +6,17 @@
package org.whispersystems.signalservice.api; package org.whispersystems.signalservice.api;
import org.signal.core.util.StreamUtil;
import org.signal.core.util.concurrent.FutureTransformers; import org.signal.core.util.concurrent.FutureTransformers;
import org.signal.core.util.concurrent.ListenableFuture; import org.signal.core.util.concurrent.ListenableFuture;
import org.signal.core.util.concurrent.SettableFuture; import org.signal.core.util.concurrent.SettableFuture;
import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.zkgroup.profiles.ClientZkProfileOperations; import org.signal.libsignal.zkgroup.profiles.ClientZkProfileOperations;
import org.signal.libsignal.zkgroup.profiles.ProfileKey; import org.signal.libsignal.zkgroup.profiles.ProfileKey;
import org.whispersystems.signalservice.api.backup.BackupKey;
import org.whispersystems.signalservice.api.backup.MediaId;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream; import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil;
import org.whispersystems.signalservice.api.crypto.ProfileCipherInputStream; import org.whispersystems.signalservice.api.crypto.ProfileCipherInputStream;
import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess; import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment.ProgressListener; import org.whispersystems.signalservice.api.messages.SignalServiceAttachment.ProgressListener;
@@ -27,6 +31,7 @@ import org.whispersystems.signalservice.api.push.exceptions.MissingConfiguration
import org.whispersystems.signalservice.api.util.CredentialsProvider; import org.whispersystems.signalservice.api.util.CredentialsProvider;
import org.whispersystems.signalservice.internal.ServiceResponse; import org.whispersystems.signalservice.internal.ServiceResponse;
import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration; import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration;
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream;
import org.whispersystems.signalservice.internal.push.IdentityCheckRequest; import org.whispersystems.signalservice.internal.push.IdentityCheckRequest;
import org.whispersystems.signalservice.internal.push.IdentityCheckResponse; import org.whispersystems.signalservice.internal.push.IdentityCheckResponse;
import org.whispersystems.signalservice.internal.push.PushServiceSocket; import org.whispersystems.signalservice.internal.push.PushServiceSocket;
@@ -36,14 +41,18 @@ import org.whispersystems.signalservice.internal.websocket.ResponseMapper;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.core.Single;
@@ -159,10 +168,60 @@ public class SignalServiceMessageReceiver {
throws IOException, InvalidMessageException, MissingConfigurationException { throws IOException, InvalidMessageException, MissingConfigurationException {
if (!pointer.getDigest().isPresent()) throw new InvalidMessageException("No attachment digest!"); if (!pointer.getDigest().isPresent()) throw new InvalidMessageException("No attachment digest!");
socket.retrieveAttachment(pointer.getCdnNumber(), pointer.getRemoteId(), destination, maxSizeBytes, listener); socket.retrieveAttachment(pointer.getCdnNumber(), Collections.emptyMap(), pointer.getRemoteId(), destination, maxSizeBytes, listener);
return AttachmentCipherInputStream.createForAttachment(destination, pointer.getSize().orElse(0), pointer.getKey(), pointer.getDigest().get(), null, 0); return AttachmentCipherInputStream.createForAttachment(destination, pointer.getSize().orElse(0), pointer.getKey(), pointer.getDigest().get(), null, 0);
} }
/**
* Retrieves an archived media attachment.
*
* @param archivedMediaKeyMaterial Decryption key material for decrypting outer layer of archived media.
* @param readCredentialHeaders Headers to pass to the backup CDN to authorize the download
* @param archiveDestination The download destination for archived attachment. If this file exists, download will resume.
* @param pointer The {@link SignalServiceAttachmentPointer} received in a {@link SignalServiceDataMessage}.
* @param attachmentDestination The download destination for this attachment. If this file exists, it is assumed that this is previously-downloaded content that can be resumed.
* @param listener An optional listener (may be null) to receive callbacks on download progress.
*
* @return An InputStream that streams the plaintext attachment contents.
*/
public InputStream retrieveArchivedAttachment(@Nonnull BackupKey.KeyMaterial<MediaId> archivedMediaKeyMaterial,
@Nonnull Map<String, String> readCredentialHeaders,
@Nonnull File archiveDestination,
@Nonnull SignalServiceAttachmentPointer pointer,
@Nonnull File attachmentDestination,
long maxSizeBytes,
@Nullable ProgressListener listener)
throws IOException, InvalidMessageException, MissingConfigurationException
{
if (pointer.getDigest().isEmpty()) {
throw new InvalidMessageException("No attachment digest!");
}
socket.retrieveAttachment(pointer.getCdnNumber(), readCredentialHeaders, pointer.getRemoteId(), archiveDestination, maxSizeBytes, listener);
long originalCipherLength = pointer.getSize()
.filter(s -> s > 0)
.map(s -> AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(s)))
.orElse(0L);
try (InputStream backupDecrypted = AttachmentCipherInputStream.createForArchivedMedia(archivedMediaKeyMaterial, archiveDestination, originalCipherLength)) {
try (FileOutputStream fos = new FileOutputStream(attachmentDestination)) {
StreamUtil.copy(backupDecrypted, fos);
}
}
return AttachmentCipherInputStream.createForAttachment(attachmentDestination,
pointer.getSize().orElse(0),
pointer.getKey(),
pointer.getDigest().get(),
null,
0);
}
public void retrieveBackup(int cdnNumber, Map<String, String> headers, String cdnPath, File destination, ProgressListener listener) throws MissingConfigurationException, IOException {
socket.retrieveBackup(cdnNumber, headers, cdnPath, destination, 1_000_000_000L, listener);
}
public InputStream retrieveSticker(byte[] packId, byte[] packKey, int stickerId) public InputStream retrieveSticker(byte[] packId, byte[] packKey, int stickerId)
throws IOException, InvalidMessageException throws IOException, InvalidMessageException
{ {
@@ -841,7 +841,7 @@ public class SignalServiceMessageSender {
Pair<Long, AttachmentDigest> attachmentIdAndDigest = socket.uploadAttachment(attachmentData, v2UploadAttributes); Pair<Long, AttachmentDigest> attachmentIdAndDigest = socket.uploadAttachment(attachmentData, v2UploadAttributes);
return new SignalServiceAttachmentPointer(0, return new SignalServiceAttachmentPointer(0,
new SignalServiceAttachmentRemoteId(attachmentIdAndDigest.first()), new SignalServiceAttachmentRemoteId.V2(attachmentIdAndDigest.first()),
attachment.getContentType(), attachment.getContentType(),
attachmentKey, attachmentKey,
Optional.of(Util.toIntExact(attachment.getLength())), Optional.of(Util.toIntExact(attachment.getLength())),
@@ -882,7 +882,7 @@ public class SignalServiceMessageSender {
private SignalServiceAttachmentPointer uploadAttachmentV4(SignalServiceAttachmentStream attachment, byte[] attachmentKey, PushAttachmentData attachmentData) throws IOException { private SignalServiceAttachmentPointer uploadAttachmentV4(SignalServiceAttachmentStream attachment, byte[] attachmentKey, PushAttachmentData attachmentData) throws IOException {
AttachmentDigest digest = socket.uploadAttachment(attachmentData); AttachmentDigest digest = socket.uploadAttachment(attachmentData);
return new SignalServiceAttachmentPointer(attachmentData.getResumableUploadSpec().getCdnNumber(), return new SignalServiceAttachmentPointer(attachmentData.getResumableUploadSpec().getCdnNumber(),
new SignalServiceAttachmentRemoteId(attachmentData.getResumableUploadSpec().getCdnKey()), new SignalServiceAttachmentRemoteId.V4(attachmentData.getResumableUploadSpec().getCdnKey()),
attachment.getContentType(), attachment.getContentType(),
attachmentKey, attachmentKey,
Optional.of(Util.toIntExact(attachment.getLength())), Optional.of(Util.toIntExact(attachment.getLength())),
@@ -55,6 +55,15 @@ class ArchiveApi(
} }
} }
fun getCdnReadCredentials(backupKey: BackupKey, serviceCredential: ArchiveServiceCredential): NetworkResult<GetArchiveCdnCredentialsResponse> {
return NetworkResult.fromFetch {
val zkCredential = getZkCredential(backupKey, serviceCredential)
val presentationData = CredentialPresentationData.from(backupKey, zkCredential, backupServerPublicParams)
pushServiceSocket.getArchiveCdnReadCredentials(presentationData.toArchiveCredentialPresentation())
}
}
/** /**
* Ensures that you reserve a backupId on the service. This must be done before any other * Ensures that you reserve a backupId on the service. This must be done before any other
* backup-related calls. You only need to do it once, but repeated calls are safe. * backup-related calls. You only need to do it once, but repeated calls are safe.
@@ -16,6 +16,8 @@ data class ArchiveGetBackupInfoResponse(
@JsonProperty @JsonProperty
val backupDir: String?, val backupDir: String?,
@JsonProperty @JsonProperty
val mediaDir: String?,
@JsonProperty
val backupName: String?, val backupName: String?,
@JsonProperty @JsonProperty
val usedSpace: Long? val usedSpace: Long?
@@ -0,0 +1,15 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.api.archive
import com.fasterxml.jackson.annotation.JsonProperty
/**
* Get response with headers to use to read from archive cdn.
*/
class GetArchiveCdnCredentialsResponse(
@JsonProperty val headers: Map<String, String>
)
@@ -5,6 +5,9 @@
package org.whispersystems.signalservice.api.backup package org.whispersystems.signalservice.api.backup
import org.signal.core.util.Base64
import java.security.MessageDigest
/** /**
* Safe typing around a backupId, which is a 16-byte array. * Safe typing around a backupId, which is a 16-byte array.
*/ */
@@ -14,4 +17,9 @@ value class BackupId(val value: ByteArray) {
init { init {
require(value.size == 16) { "BackupId must be 16 bytes!" } require(value.size == 16) { "BackupId must be 16 bytes!" }
} }
/** Encode backup-id for use in a URL/request */
fun encode(): String {
return Base64.encodeUrlSafeWithPadding(MessageDigest.getInstance("SHA-256").digest(value).copyOfRange(0, 16))
}
} }
@@ -16,10 +16,14 @@ class BackupKey(val value: ByteArray) {
require(value.size == 32) { "Backup key must be 32 bytes!" } require(value.size == 32) { "Backup key must be 32 bytes!" }
} }
fun deriveSecrets(aci: ACI): KeyMaterial<BackupId> { fun deriveBackupId(aci: ACI): BackupId {
val backupId = BackupId( return BackupId(
HKDF.deriveSecrets(this.value, aci.toByteArray(), "20231003_Signal_Backups_GenerateBackupId".toByteArray(), 16) HKDF.deriveSecrets(this.value, aci.toByteArray(), "20231003_Signal_Backups_GenerateBackupId".toByteArray(), 16)
) )
}
fun deriveSecrets(aci: ACI): KeyMaterial<BackupId> {
val backupId = deriveBackupId(aci)
val extendedKey = HKDF.deriveSecrets(this.value, backupId.value, "20231003_Signal_Backups_EncryptMessageBackup".toByteArray(), 80) val extendedKey = HKDF.deriveSecrets(this.value, backupId.value, "20231003_Signal_Backups_EncryptMessageBackup".toByteArray(), 80)
@@ -31,13 +35,15 @@ class BackupKey(val value: ByteArray) {
) )
} }
fun deriveMediaId(dataHash: ByteArray): MediaId { fun deriveMediaId(mediaName: MediaName): MediaId {
return MediaId(HKDF.deriveSecrets(value, dataHash, "Media ID".toByteArray(), 15)) return MediaId(HKDF.deriveSecrets(value, mediaName.toByteArray(), "Media ID".toByteArray(), 15))
} }
fun deriveMediaSecrets(dataHash: ByteArray): KeyMaterial<MediaId> { fun deriveMediaSecrets(mediaName: MediaName): KeyMaterial<MediaId> {
val mediaId = deriveMediaId(dataHash) return deriveMediaSecrets(deriveMediaId(mediaName))
}
fun deriveMediaSecrets(mediaId: MediaId): KeyMaterial<MediaId> {
val extendedKey = HKDF.deriveSecrets(this.value, mediaId.value, "20231003_Signal_Backups_EncryptMedia".toByteArray(), 80) val extendedKey = HKDF.deriveSecrets(this.value, mediaId.value, "20231003_Signal_Backups_EncryptMedia".toByteArray(), 80)
return KeyMaterial( return KeyMaterial(
@@ -53,5 +59,17 @@ class BackupKey(val value: ByteArray) {
val macKey: ByteArray, val macKey: ByteArray,
val cipherKey: ByteArray, val cipherKey: ByteArray,
val iv: ByteArray val iv: ByteArray
) ) {
companion object {
@JvmStatic
fun forMedia(id: ByteArray, keyMac: ByteArray, iv: ByteArray): KeyMaterial<MediaId> {
return KeyMaterial(
MediaId(id),
keyMac.copyOfRange(32, 64),
keyMac.copyOfRange(0, 32),
iv
)
}
}
}
} }
@@ -13,11 +13,14 @@ import org.signal.core.util.Base64
@JvmInline @JvmInline
value class MediaId(val value: ByteArray) { value class MediaId(val value: ByteArray) {
constructor(mediaId: String) : this(Base64.decode(mediaId))
init { init {
require(value.size == 15) { "MediaId must be 15 bytes!" } require(value.size == 15) { "MediaId must be 15 bytes!" }
} }
override fun toString(): String { /** Encode media-id for use in a URL/request */
fun encode(): String {
return Base64.encodeUrlSafeWithPadding(value) return Base64.encodeUrlSafeWithPadding(value)
} }
} }
@@ -0,0 +1,24 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.api.backup
import org.signal.core.util.Base64
/**
* Represent a media name for the various types of media that can be archived.
*/
@JvmInline
value class MediaName(val name: String) {
companion object {
fun fromDigest(digest: ByteArray) = MediaName(Base64.encodeWithoutPadding(digest))
fun fromDigestForThumbnail(digest: ByteArray) = MediaName("${Base64.encodeWithoutPadding(digest)}_thumbnail")
}
fun toByteArray(): ByteArray {
return name.toByteArray()
}
}
@@ -10,7 +10,9 @@ import org.signal.libsignal.protocol.InvalidMacException;
import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice; import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice;
import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream; import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream;
import org.signal.libsignal.protocol.kdf.HKDFv3; import org.signal.libsignal.protocol.kdf.HKDF;
import org.whispersystems.signalservice.api.backup.BackupKey;
import org.whispersystems.signalservice.api.backup.MediaId;
import org.whispersystems.signalservice.internal.util.ContentLengthInputStream; import org.whispersystems.signalservice.internal.util.ContentLengthInputStream;
import org.whispersystems.signalservice.internal.util.Util; import org.whispersystems.signalservice.internal.util.Util;
@@ -26,6 +28,8 @@ import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.util.Arrays; import java.util.Arrays;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.crypto.BadPaddingException; import javax.crypto.BadPaddingException;
import javax.crypto.Cipher; import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException; import javax.crypto.IllegalBlockSizeException;
@@ -47,9 +51,10 @@ public class AttachmentCipherInputStream extends FilterInputStream {
private static final int CIPHER_KEY_SIZE = 32; private static final int CIPHER_KEY_SIZE = 32;
private static final int MAC_KEY_SIZE = 32; private static final int MAC_KEY_SIZE = 32;
private Cipher cipher; private final Cipher cipher;
private final long totalDataSize;
private boolean done; private boolean done;
private long totalDataSize;
private long totalRead; private long totalRead;
private byte[] overflowBuffer; private byte[] overflowBuffer;
@@ -102,11 +107,43 @@ public class AttachmentCipherInputStream extends FilterInputStream {
} }
} }
/**
* Decrypt archived media to it's original attachment encrypted blob.
*/
public static InputStream createForArchivedMedia(BackupKey.KeyMaterial<MediaId> archivedMediaKeyMaterial, File file, long originalCipherTextLength)
throws InvalidMessageException, IOException
{
try {
Mac mac = Mac.getInstance("HmacSHA256");
mac.init(new SecretKeySpec(archivedMediaKeyMaterial.getMacKey(), "HmacSHA256"));
if (file.length() <= BLOCK_SIZE + mac.getMacLength()) {
throw new InvalidMessageException("Message shorter than crypto overhead!");
}
try (FileInputStream macVerificationStream = new FileInputStream(file)) {
verifyMac(macVerificationStream, file.length(), mac, null);
}
InputStream inputStream = new AttachmentCipherInputStream(new FileInputStream(file), archivedMediaKeyMaterial.getCipherKey(), file.length() - BLOCK_SIZE - mac.getMacLength());
if (originalCipherTextLength != 0) {
inputStream = new ContentLengthInputStream(inputStream, originalCipherTextLength);
}
return inputStream;
} catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new AssertionError(e);
} catch (InvalidMacException e) {
throw new InvalidMessageException(e);
}
}
public static InputStream createForStickerData(byte[] data, byte[] packKey) public static InputStream createForStickerData(byte[] data, byte[] packKey)
throws InvalidMessageException, IOException throws InvalidMessageException, IOException
{ {
try { try {
byte[] combinedKeyMaterial = new HKDFv3().deriveSecrets(packKey, "Sticker Pack".getBytes(), 64); byte[] combinedKeyMaterial = HKDF.deriveSecrets(packKey, "Sticker Pack".getBytes(), 64);
byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE); byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE);
Mac mac = Mac.getInstance("HmacSHA256"); Mac mac = Mac.getInstance("HmacSHA256");
mac.init(new SecretKeySpec(parts[1], "HmacSHA256")); mac.init(new SecretKeySpec(parts[1], "HmacSHA256"));
@@ -159,12 +196,12 @@ public class AttachmentCipherInputStream extends FilterInputStream {
} }
@Override @Override
public int read(byte[] buffer) throws IOException { public int read(@Nonnull byte[] buffer) throws IOException {
return read(buffer, 0, buffer.length); return read(buffer, 0, buffer.length);
} }
@Override @Override
public int read(byte[] buffer, int offset, int length) throws IOException { public int read(@Nonnull byte[] buffer, int offset, int length) throws IOException {
if (totalRead != totalDataSize) { if (totalRead != totalDataSize) {
return readIncremental(buffer, offset, length); return readIncremental(buffer, offset, length);
} else if (!done) { } else if (!done) {
@@ -256,7 +293,7 @@ public class AttachmentCipherInputStream extends FilterInputStream {
} }
} }
private static void verifyMac(InputStream inputStream, long length, Mac mac, byte[] theirDigest) private static void verifyMac(@Nonnull InputStream inputStream, long length, @Nonnull Mac mac, @Nullable byte[] theirDigest)
throws InvalidMacException throws InvalidMacException
{ {
try { try {
@@ -1,67 +0,0 @@
package org.whispersystems.signalservice.api.messages;
import org.whispersystems.signalservice.api.InvalidMessageStructureException;
import org.whispersystems.signalservice.internal.push.AttachmentPointer;
import java.util.Optional;
/**
* Represents a signal service attachment identifier. This can be either a CDN key or a long, but
* not both at once. Attachments V2 used a long as an attachment identifier. This lacks sufficient
* entropy to reduce the likelihood of any two uploads going to the same location within a 30-day
* window. Attachments V3 uses an opaque string as an attachment identifier which provides more
* flexibility in the amount of entropy present.
*/
public final class SignalServiceAttachmentRemoteId {
private final Optional<Long> v2;
private final Optional<String> v3;
public SignalServiceAttachmentRemoteId(long v2) {
this.v2 = Optional.of(v2);
this.v3 = Optional.empty();
}
public SignalServiceAttachmentRemoteId(String v3) {
this.v2 = Optional.empty();
this.v3 = Optional.of(v3);
}
public Optional<Long> getV2() {
return v2;
}
public Optional<String> getV3() {
return v3;
}
@Override
public String toString() {
if (v2.isPresent()) {
return v2.get().toString();
} else {
return v3.get();
}
}
public static SignalServiceAttachmentRemoteId from(AttachmentPointer attachmentPointer) throws InvalidMessageStructureException {
if (attachmentPointer.cdnKey != null) {
return new SignalServiceAttachmentRemoteId(attachmentPointer.cdnKey);
} else if (attachmentPointer.cdnId != null && attachmentPointer.cdnId > 0) {
return new SignalServiceAttachmentRemoteId(attachmentPointer.cdnId);
} else {
throw new InvalidMessageStructureException("AttachmentPointer CDN location not set");
}
}
/**
* Guesses that strings which contain values parseable to {@code long} should use an id-based
* CDN path. Otherwise, use key-based CDN path.
*/
public static SignalServiceAttachmentRemoteId from(String string) {
try {
return new SignalServiceAttachmentRemoteId(Long.parseLong(string));
} catch (NumberFormatException e) {
return new SignalServiceAttachmentRemoteId(string);
}
}
}
@@ -0,0 +1,54 @@
package org.whispersystems.signalservice.api.messages
import org.whispersystems.signalservice.api.InvalidMessageStructureException
import org.whispersystems.signalservice.internal.push.AttachmentPointer
/**
* Represents a signal service attachment identifier. This can be either a CDN key or a long, but
* not both at once. Attachments V2 used a long as an attachment identifier. This lacks sufficient
* entropy to reduce the likelihood of any two uploads going to the same location within a 30-day
* window. Attachments V4 (backwards compatible with V3) uses an opaque string as an attachment
* identifier which provides more flexibility in the amount of entropy present.
*/
sealed interface SignalServiceAttachmentRemoteId {
object S3 : SignalServiceAttachmentRemoteId {
override fun toString() = ""
}
data class V2(val cdnId: Long) : SignalServiceAttachmentRemoteId {
override fun toString() = cdnId.toString()
}
data class V4(val cdnKey: String) : SignalServiceAttachmentRemoteId {
override fun toString() = cdnKey
}
data class Backup(val backupDir: String, val mediaDir: String, val mediaId: String) : SignalServiceAttachmentRemoteId {
override fun toString() = mediaId
}
companion object {
@JvmStatic
@Throws(InvalidMessageStructureException::class)
fun from(attachmentPointer: AttachmentPointer): SignalServiceAttachmentRemoteId {
return if (attachmentPointer.cdnKey != null) {
V4(attachmentPointer.cdnKey)
} else if (attachmentPointer.cdnId != null && attachmentPointer.cdnId > 0) {
V2(attachmentPointer.cdnId)
} else {
throw InvalidMessageStructureException("AttachmentPointer CDN location not set")
}
}
/**
* Guesses that strings which contain values parseable to `long` should use an id-based
* CDN path. Otherwise, use key-based CDN path.
*/
@JvmStatic
fun from(string: String): SignalServiceAttachmentRemoteId {
return string.toLongOrNull()?.let { V2(it) } ?: V4(string)
}
}
}
@@ -56,12 +56,12 @@ public final class AttachmentPointerUtil {
builder.incrementalMacChunkSize(attachment.getIncrementalMacChunkSize()); builder.incrementalMacChunkSize(attachment.getIncrementalMacChunkSize());
} }
if (attachment.getRemoteId().getV2().isPresent()) { if (attachment.getRemoteId() instanceof SignalServiceAttachmentRemoteId.V2) {
builder.cdnId(attachment.getRemoteId().getV2().get()); builder.cdnId(((SignalServiceAttachmentRemoteId.V2) attachment.getRemoteId()).getCdnId());
} }
if (attachment.getRemoteId().getV3().isPresent()) { if (attachment.getRemoteId() instanceof SignalServiceAttachmentRemoteId.V4) {
builder.cdnKey(attachment.getRemoteId().getV3().get()); builder.cdnKey(((SignalServiceAttachmentRemoteId.V4) attachment.getRemoteId()).getCdnKey());
} }
if (attachment.getFileName().isPresent()) { if (attachment.getFileName().isPresent()) {
@@ -58,6 +58,7 @@ import org.whispersystems.signalservice.api.archive.ArchiveSetPublicKeyRequest;
import org.whispersystems.signalservice.api.archive.BatchArchiveMediaRequest; import org.whispersystems.signalservice.api.archive.BatchArchiveMediaRequest;
import org.whispersystems.signalservice.api.archive.BatchArchiveMediaResponse; import org.whispersystems.signalservice.api.archive.BatchArchiveMediaResponse;
import org.whispersystems.signalservice.api.archive.DeleteArchivedMediaRequest; import org.whispersystems.signalservice.api.archive.DeleteArchivedMediaRequest;
import org.whispersystems.signalservice.api.archive.GetArchiveCdnCredentialsResponse;
import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess; import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess;
import org.whispersystems.signalservice.api.groupsv2.CredentialResponse; import org.whispersystems.signalservice.api.groupsv2.CredentialResponse;
import org.whispersystems.signalservice.api.groupsv2.GroupsV2AuthorizationString; import org.whispersystems.signalservice.api.groupsv2.GroupsV2AuthorizationString;
@@ -319,6 +320,7 @@ public class PushServiceSocket {
private static final String ARCHIVE_MEDIA_LIST = "/v1/archives/media?limit=%d"; private static final String ARCHIVE_MEDIA_LIST = "/v1/archives/media?limit=%d";
private static final String ARCHIVE_MEDIA_BATCH = "/v1/archives/media/batch"; private static final String ARCHIVE_MEDIA_BATCH = "/v1/archives/media/batch";
private static final String ARCHIVE_MEDIA_DELETE = "/v1/archives/media/delete"; private static final String ARCHIVE_MEDIA_DELETE = "/v1/archives/media/delete";
private static final String ARCHIVE_MEDIA_DOWNLOAD_PATH = "backups/%s/%s/%s";
private static final String CALL_LINK_CREATION_AUTH = "/v1/call-link/create-auth"; private static final String CALL_LINK_CREATION_AUTH = "/v1/call-link/create-auth";
private static final String SERVER_DELIVERED_TIMESTAMP_HEADER = "X-Signal-Timestamp"; private static final String SERVER_DELIVERED_TIMESTAMP_HEADER = "X-Signal-Timestamp";
@@ -585,6 +587,16 @@ public class PushServiceSocket {
return JsonUtil.fromJson(response, ArchiveMessageBackupUploadFormResponse.class); return JsonUtil.fromJson(response, ArchiveMessageBackupUploadFormResponse.class);
} }
/**
* Copy and re-encrypt media from the attachments cdn into the backup cdn.
*/
public GetArchiveCdnCredentialsResponse getArchiveCdnReadCredentials(@Nonnull ArchiveCredentialPresentation credentialPresentation) throws IOException {
Map<String, String> headers = credentialPresentation.toHeaders();
String response = makeServiceRequestWithoutAuthentication(ARCHIVE_READ_CREDENTIALS, "GET", null, headers, NO_HANDLER);
return JsonUtil.fromJson(response, GetArchiveCdnCredentialsResponse.class);
}
public VerifyAccountResponse changeNumber(@Nonnull ChangePhoneNumberRequest changePhoneNumberRequest) public VerifyAccountResponse changeNumber(@Nonnull ChangePhoneNumberRequest changePhoneNumberRequest)
throws IOException throws IOException
@@ -919,16 +931,27 @@ public class PushServiceSocket {
}, Optional.empty()); }, Optional.empty());
} }
public void retrieveAttachment(int cdnNumber, SignalServiceAttachmentRemoteId cdnPath, File destination, long maxSizeBytes, ProgressListener listener) public void retrieveBackup(int cdnNumber, Map<String, String> headers, String cdnPath, File destination, long maxSizeBytes, ProgressListener listener)
throws MissingConfigurationException, IOException
{
downloadFromCdn(destination, cdnNumber, headers, cdnPath, maxSizeBytes, listener);
}
public void retrieveAttachment(int cdnNumber, Map<String, String> headers, SignalServiceAttachmentRemoteId cdnPath, File destination, long maxSizeBytes, ProgressListener listener)
throws IOException, MissingConfigurationException throws IOException, MissingConfigurationException
{ {
final String path; final String path;
if (cdnPath.getV2().isPresent()) { if (cdnPath instanceof SignalServiceAttachmentRemoteId.V2) {
path = String.format(Locale.US, ATTACHMENT_ID_DOWNLOAD_PATH, cdnPath.getV2().get()); path = String.format(Locale.US, ATTACHMENT_ID_DOWNLOAD_PATH, ((SignalServiceAttachmentRemoteId.V2) cdnPath).getCdnId());
} else if (cdnPath instanceof SignalServiceAttachmentRemoteId.V4) {
path = String.format(Locale.US, ATTACHMENT_KEY_DOWNLOAD_PATH, ((SignalServiceAttachmentRemoteId.V4) cdnPath).getCdnKey());
} else if (cdnPath instanceof SignalServiceAttachmentRemoteId.Backup) {
SignalServiceAttachmentRemoteId.Backup backupCdnId = (SignalServiceAttachmentRemoteId.Backup) cdnPath;
path = String.format(Locale.US, ARCHIVE_MEDIA_DOWNLOAD_PATH, backupCdnId.getBackupDir(), backupCdnId.getMediaDir(), backupCdnId.getMediaId());
} else { } else {
path = String.format(Locale.US, ATTACHMENT_KEY_DOWNLOAD_PATH, cdnPath.getV3().get()); throw new IllegalArgumentException("Invalid cdnPath type: " + cdnPath.getClass().getSimpleName());
} }
downloadFromCdn(destination, cdnNumber, path, maxSizeBytes, listener); downloadFromCdn(destination, cdnNumber, headers, path, maxSizeBytes, listener);
} }
public byte[] retrieveSticker(byte[] packId, int stickerId) public byte[] retrieveSticker(byte[] packId, int stickerId)
@@ -937,7 +960,7 @@ public class PushServiceSocket {
ByteArrayOutputStream output = new ByteArrayOutputStream(); ByteArrayOutputStream output = new ByteArrayOutputStream();
try { try {
downloadFromCdn(output, 0, 0, String.format(Locale.US, STICKER_PATH, hexPackId, stickerId), 1024 * 1024, null); downloadFromCdn(output, 0, 0, Collections.emptyMap(), String.format(Locale.US, STICKER_PATH, hexPackId, stickerId), 1024 * 1024, null);
} catch (MissingConfigurationException e) { } catch (MissingConfigurationException e) {
throw new AssertionError(e); throw new AssertionError(e);
} }
@@ -951,7 +974,7 @@ public class PushServiceSocket {
ByteArrayOutputStream output = new ByteArrayOutputStream(); ByteArrayOutputStream output = new ByteArrayOutputStream();
try { try {
downloadFromCdn(output, 0, 0, String.format(STICKER_MANIFEST_PATH, hexPackId), 1024 * 1024, null); downloadFromCdn(output, 0, 0, Collections.emptyMap(), String.format(STICKER_MANIFEST_PATH, hexPackId), 1024 * 1024, null);
} catch (MissingConfigurationException e) { } catch (MissingConfigurationException e) {
throw new AssertionError(e); throw new AssertionError(e);
} }
@@ -1029,7 +1052,7 @@ public class PushServiceSocket {
throws IOException throws IOException
{ {
try { try {
downloadFromCdn(destination, 0, path, maxSizeBytes, null); downloadFromCdn(destination, 0, Collections.emptyMap(), path, maxSizeBytes, null);
} catch (MissingConfigurationException e) { } catch (MissingConfigurationException e) {
throw new AssertionError(e); throw new AssertionError(e);
} }
@@ -1577,15 +1600,15 @@ public class PushServiceSocket {
} }
} }
private void downloadFromCdn(File destination, int cdnNumber, String path, long maxSizeBytes, ProgressListener listener) private void downloadFromCdn(File destination, int cdnNumber, Map<String, String> headers, String path, long maxSizeBytes, ProgressListener listener)
throws IOException, MissingConfigurationException throws IOException, MissingConfigurationException
{ {
try (FileOutputStream outputStream = new FileOutputStream(destination, true)) { try (FileOutputStream outputStream = new FileOutputStream(destination, true)) {
downloadFromCdn(outputStream, destination.length(), cdnNumber, path, maxSizeBytes, listener); downloadFromCdn(outputStream, destination.length(), cdnNumber, headers, path, maxSizeBytes, listener);
} }
} }
private void downloadFromCdn(OutputStream outputStream, long offset, int cdnNumber, String path, long maxSizeBytes, ProgressListener listener) private void downloadFromCdn(OutputStream outputStream, long offset, int cdnNumber, Map<String, String> headers, String path, long maxSizeBytes, ProgressListener listener)
throws PushNetworkException, NonSuccessfulResponseCodeException, MissingConfigurationException { throws PushNetworkException, NonSuccessfulResponseCodeException, MissingConfigurationException {
ConnectionHolder[] cdnNumberClients = cdnClientsMap.get(cdnNumber); ConnectionHolder[] cdnNumberClients = cdnClientsMap.get(cdnNumber);
if (cdnNumberClients == null) { if (cdnNumberClients == null) {
@@ -1604,6 +1627,10 @@ public class PushServiceSocket {
request.addHeader("Host", connectionHolder.getHostHeader().get()); request.addHeader("Host", connectionHolder.getHostHeader().get());
} }
for (Map.Entry<String, String> header : headers.entrySet()) {
request.addHeader(header.getKey(), header.getValue());
}
if (offset > 0) { if (offset > 0) {
Log.i(TAG, "Starting download from CDN with offset " + offset); Log.i(TAG, "Starting download from CDN with offset " + offset);
request.addHeader("Range", "bytes=" + offset + "-"); request.addHeader("Range", "bytes=" + offset + "-");
@@ -300,7 +300,7 @@ public class WebSocketConnection extends WebSocketListener {
OutgoingRequest listener = outgoingRequests.remove(message.response.id); OutgoingRequest listener = outgoingRequests.remove(message.response.id);
if (listener != null) { if (listener != null) {
listener.onSuccess(new WebsocketResponse(message.response.status, listener.onSuccess(new WebsocketResponse(message.response.status,
new String(message.response.body.toByteArray()), message.response.body == null ? "" : new String(message.response.body.toByteArray()),
message.response.headers, message.response.headers,
!credentialsProvider.isPresent())); !credentialsProvider.isPresent()));
if (message.response.status >= 400) { if (message.response.status >= 400) {
@@ -6,6 +6,8 @@ import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice; import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice;
import org.signal.libsignal.protocol.incrementalmac.InvalidMacException; import org.signal.libsignal.protocol.incrementalmac.InvalidMacException;
import org.signal.libsignal.protocol.kdf.HKDFv3; import org.signal.libsignal.protocol.kdf.HKDFv3;
import org.whispersystems.signalservice.api.backup.BackupKey;
import org.whispersystems.signalservice.api.backup.MediaId;
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream; import org.whispersystems.signalservice.internal.crypto.PaddingInputStream;
import org.whispersystems.signalservice.internal.push.http.AttachmentCipherOutputStreamFactory; import org.whispersystems.signalservice.internal.push.http.AttachmentCipherOutputStreamFactory;
import org.whispersystems.signalservice.internal.util.Util; import org.whispersystems.signalservice.internal.util.Util;
@@ -88,6 +90,62 @@ public final class AttachmentCipherTest {
assertTrue(hitCorrectException); assertTrue(hitCorrectException);
} }
@Test
public void archive_encryptDecrypt() throws IOException, InvalidMessageException {
byte[] key = Util.getSecretBytes(64);
BackupKey.KeyMaterial<MediaId> keyMaterial = BackupKey.KeyMaterial.forMedia(Util.getSecretBytes(15), key, Util.getSecretBytes(16));
byte[] plaintextInput = "Peter Parker".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key, false);
File cipherFile = writeToFile(encryptResult.ciphertext);
InputStream inputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.length);
byte[] plaintextOutput = readInputStreamFully(inputStream);
assertArrayEquals(plaintextInput, plaintextOutput);
cipherFile.delete();
}
@Test
public void archive_encryptDecryptEmpty() throws IOException, InvalidMessageException {
byte[] key = Util.getSecretBytes(64);
BackupKey.KeyMaterial<MediaId> keyMaterial = BackupKey.KeyMaterial.forMedia(Util.getSecretBytes(15), key, Util.getSecretBytes(16));
byte[] plaintextInput = "".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key, false);
File cipherFile = writeToFile(encryptResult.ciphertext);
InputStream inputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.length);
byte[] plaintextOutput = readInputStreamFully(inputStream);
assertArrayEquals(plaintextInput, plaintextOutput);
cipherFile.delete();
}
@Test
public void archive_decryptFailOnBadKey() throws IOException {
File cipherFile = null;
boolean hitCorrectException = false;
try {
byte[] key = Util.getSecretBytes(64);
byte[] badKey = Util.getSecretBytes(64);
BackupKey.KeyMaterial<MediaId> keyMaterial = BackupKey.KeyMaterial.forMedia(Util.getSecretBytes(15), badKey, Util.getSecretBytes(16));
byte[] plaintextInput = "Gwen Stacy".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key, false);
cipherFile = writeToFile(encryptResult.ciphertext);
AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.length);
} catch (InvalidMessageException e) {
hitCorrectException = true;
} finally {
if (cipherFile != null) {
cipherFile.delete();
}
}
assertTrue(hitCorrectException);
}
@Test @Test
public void attachment_decryptFailOnBadDigest() throws IOException { public void attachment_decryptFailOnBadDigest() throws IOException {
File cipherFile = null; File cipherFile = null;
@@ -184,6 +242,44 @@ public final class AttachmentCipherTest {
} }
} }
@Test
public void archive_encryptDecryptPaddedContent() throws IOException, InvalidMessageException {
int[] lengths = { 531, 600, 724, 1019, 1024 };
for (int length : lengths) {
byte[] plaintextInput = new byte[length];
for (int i = 0; i < length; i++) {
plaintextInput[i] = (byte) 0x97;
}
byte[] key = Util.getSecretBytes(64);
byte[] iv = Util.getSecretBytes(16);
ByteArrayInputStream inputStream = new ByteArrayInputStream(plaintextInput);
InputStream paddedInputStream = new PaddingInputStream(inputStream, length);
ByteArrayOutputStream destinationOutputStream = new ByteArrayOutputStream();
DigestingOutputStream encryptingOutputStream = new AttachmentCipherOutputStreamFactory(key, iv).createFor(destinationOutputStream);
Util.copy(paddedInputStream, encryptingOutputStream);
encryptingOutputStream.flush();
encryptingOutputStream.close();
byte[] encryptedData = destinationOutputStream.toByteArray();
File cipherFile = writeToFile(encryptedData);
BackupKey.KeyMaterial<MediaId> keyMaterial = BackupKey.KeyMaterial.forMedia(Util.getSecretBytes(15), key, Util.getSecretBytes(16));
InputStream decryptedStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, length);
byte[] plaintextOutput = readInputStreamFully(decryptedStream);
assertArrayEquals(plaintextInput, plaintextOutput);
cipherFile.delete();
}
}
@Test @Test
public void attachment_decryptFailOnNullDigest() throws IOException { public void attachment_decryptFailOnNullDigest() throws IOException {
File cipherFile = null; File cipherFile = null;
@@ -237,6 +333,35 @@ public final class AttachmentCipherTest {
assertTrue(hitCorrectException); assertTrue(hitCorrectException);
} }
@Test
public void archive_decryptFailOnBadMac() throws IOException {
File cipherFile = null;
boolean hitCorrectException = false;
try {
byte[] key = Util.getSecretBytes(64);
byte[] plaintextInput = "Uncle Ben".getBytes();
EncryptResult encryptResult = encryptData(plaintextInput, key, true);
byte[] badMacCiphertext = Arrays.copyOf(encryptResult.ciphertext, encryptResult.ciphertext.length);
badMacCiphertext[badMacCiphertext.length - 1] += 1;
cipherFile = writeToFile(badMacCiphertext);
BackupKey.KeyMaterial<MediaId> keyMaterial = BackupKey.KeyMaterial.forMedia(Util.getSecretBytes(15), key, Util.getSecretBytes(16));
AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.length);
fail();
} catch (InvalidMessageException e) {
hitCorrectException = true;
} finally {
if (cipherFile != null) {
cipherFile.delete();
}
}
assertTrue(hitCorrectException);
}
@Test @Test
public void sticker_encryptDecrypt() throws IOException, InvalidMessageException { public void sticker_encryptDecrypt() throws IOException, InvalidMessageException {
assumeLibSignalSupportedOnOS(); assumeLibSignalSupportedOnOS();