Add polls to backups.

This commit is contained in:
Michelle Tang
2025-10-14 11:41:47 -04:00
committed by Cody Henthorne
parent a2aabeaad2
commit 525175f04a
14 changed files with 236 additions and 54 deletions

View File

@@ -119,6 +119,10 @@ object ExportSkips {
return log(sentTimestamp, "Failed to parse thread merge event.")
}
fun pollTerminateIsEmpty(sentTimestamp: Long): String {
return log(sentTimestamp, "Poll terminate update was empty.")
}
fun individualChatUpdateInWrongTypeOfChat(sentTimestamp: Long): String {
return log(sentTimestamp, "A chat update that only makes sense for individual chats was found in a different kind of chat.")
}

View File

@@ -50,6 +50,8 @@ import org.thoughtcrime.securesms.backup.v2.proto.IndividualCall
import org.thoughtcrime.securesms.backup.v2.proto.LearnedProfileChatUpdate
import org.thoughtcrime.securesms.backup.v2.proto.MessageAttachment
import org.thoughtcrime.securesms.backup.v2.proto.PaymentNotification
import org.thoughtcrime.securesms.backup.v2.proto.Poll
import org.thoughtcrime.securesms.backup.v2.proto.PollTerminateUpdate
import org.thoughtcrime.securesms.backup.v2.proto.ProfileChangeChatUpdate
import org.thoughtcrime.securesms.backup.v2.proto.Quote
import org.thoughtcrime.securesms.backup.v2.proto.Reaction
@@ -93,6 +95,7 @@ import org.thoughtcrime.securesms.mms.PartAuthority
import org.thoughtcrime.securesms.mms.QuoteModel
import org.thoughtcrime.securesms.payments.FailureReason
import org.thoughtcrime.securesms.payments.State
import org.thoughtcrime.securesms.polls.PollRecord
import org.thoughtcrime.securesms.recipients.RecipientId
import org.thoughtcrime.securesms.util.JsonUtils
import org.thoughtcrime.securesms.util.MediaUtil
@@ -371,6 +374,22 @@ class ChatItemArchiveExporter(
transformTimer.emit("story")
}
MessageTypes.isPollTerminate(record.type) -> {
val pollTerminateUpdate = record.toRemotePollTerminateUpdate()
if (pollTerminateUpdate == null) {
Log.w(TAG, ExportSkips.pollTerminateIsEmpty(record.dateSent))
continue
}
builder.updateMessage = ChatUpdateMessage(pollTerminate = pollTerminateUpdate)
transformTimer.emit("poll-terminate")
}
extraData.pollsById[record.id] != null -> {
val poll = extraData.pollsById[record.id]!!
builder.poll = poll.toRemotePollMessage()
transformTimer.emit("poll")
}
else -> {
val attachments = extraData.attachmentsById[record.id]
val sticker = attachments?.firstOrNull { dbAttachment -> dbAttachment.isSticker }
@@ -471,16 +490,24 @@ class ChatItemArchiveExporter(
}
}
val pollsFuture = executor.submitTyped {
extraDataTimer.timeEvent("polls") {
db.pollTable.getPollsForMessages(messageIds)
}
}
val mentionsResult = mentionsFuture.get()
val reactionsResult = reactionsFuture.get()
val attachmentsResult = attachmentsFuture.get()
val groupReceiptsResult = groupReceiptsFuture.get()
val pollsResult = pollsFuture.get()
return ExtraMessageData(
mentionsById = mentionsResult,
reactionsById = reactionsResult,
attachmentsById = attachmentsResult,
groupReceiptsById = groupReceiptsResult
groupReceiptsById = groupReceiptsResult,
pollsById = pollsResult
)
}
}
@@ -783,6 +810,14 @@ private fun BackupMessageRecord.toRemotePaymentNotificationUpdate(db: SignalData
}
}
private fun BackupMessageRecord.toRemotePollTerminateUpdate(): PollTerminateUpdate? {
val pollTerminate = this.messageExtras?.pollTerminate ?: return null
return PollTerminateUpdate(
targetSentTimestamp = pollTerminate.targetTimestamp,
question = pollTerminate.question
)
}
private fun BackupMessageRecord.toRemoteSharedContact(attachments: List<DatabaseAttachment>?): Contact? {
if (this.sharedContacts.isNullOrEmpty()) {
return null
@@ -1131,6 +1166,25 @@ private fun BackupMessageRecord.toRemoteGiftBadgeUpdate(): BackupGiftBadge? {
)
}
private fun PollRecord.toRemotePollMessage(): Poll {
return Poll(
question = this.question,
allowMultiple = this.allowMultipleVotes,
hasEnded = this.hasEnded,
options = this.pollOptions.map { option ->
Poll.PollOption(
option = option.text,
votes = option.voters.map { voter ->
Poll.PollOption.PollVote(
voterId = voter.id,
voteCount = voter.voteCount
)
}
)
}
)
}
private fun DatabaseAttachment.toRemoteStickerMessage(sentTimestamp: Long, reactions: List<ReactionRecord>?): StickerMessage? {
val stickerLocator = this.stickerLocator!!
@@ -1491,7 +1545,8 @@ private fun Long.isDirectionlessType(): Boolean {
MessageTypes.isGroupCall(this) ||
MessageTypes.isGroupUpdate(this) ||
MessageTypes.isGroupV1MigrationEvent(this) ||
MessageTypes.isGroupQuit(this)
MessageTypes.isGroupQuit(this) ||
MessageTypes.isPollTerminate(this)
}
private fun Long.isIdentityVerifyType(): Boolean {
@@ -1522,7 +1577,8 @@ private fun ChatItem.validateChatItem(exportState: ExportState): ChatItem? {
this.paymentNotification == null &&
this.giftBadge == null &&
this.viewOnceMessage == null &&
this.directStoryReplyMessage == null
this.directStoryReplyMessage == null &&
this.poll == null
) {
Log.w(TAG, ExportSkips.emptyChatItem(this.dateSent))
return null
@@ -1693,7 +1749,8 @@ private data class ExtraMessageData(
val mentionsById: Map<Long, List<Mention>>,
val reactionsById: Map<Long, List<ReactionRecord>>,
val attachmentsById: Map<Long, List<DatabaseAttachment>>,
val groupReceiptsById: Map<Long, List<GroupReceiptTable.GroupReceiptInfo>>
val groupReceiptsById: Map<Long, List<GroupReceiptTable.GroupReceiptInfo>>,
val pollsById: Map<Long, PollRecord>
)
private enum class Direction {

View File

@@ -62,6 +62,7 @@ import org.thoughtcrime.securesms.database.model.databaseprotos.GV2UpdateDescrip
import org.thoughtcrime.securesms.database.model.databaseprotos.GiftBadge
import org.thoughtcrime.securesms.database.model.databaseprotos.MessageExtras
import org.thoughtcrime.securesms.database.model.databaseprotos.PaymentTombstone
import org.thoughtcrime.securesms.database.model.databaseprotos.PollTerminate
import org.thoughtcrime.securesms.database.model.databaseprotos.ProfileChangeDetails
import org.thoughtcrime.securesms.database.model.databaseprotos.SessionSwitchoverEvent
import org.thoughtcrime.securesms.database.model.databaseprotos.ThreadMergeEvent
@@ -72,6 +73,7 @@ import org.thoughtcrime.securesms.payments.Direction
import org.thoughtcrime.securesms.payments.FailureReason
import org.thoughtcrime.securesms.payments.State
import org.thoughtcrime.securesms.payments.proto.PaymentMetaData
import org.thoughtcrime.securesms.polls.Voter
import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.recipients.RecipientId
import org.thoughtcrime.securesms.stickers.StickerLocator
@@ -304,6 +306,21 @@ class ChatItemArchiveImporter(
)
db.insert(CallTable.TABLE_NAME, SQLiteDatabase.CONFLICT_IGNORE, values)
}
} else if (this.updateMessage.pollTerminate != null) {
followUps += { endPollMessageId ->
val pollMessageId = SignalDatabase.messages.getMessageFor(updateMessage.pollTerminate.targetSentTimestamp, fromRecipientId)?.id ?: -1
val pollId = SignalDatabase.polls.getPollId(pollMessageId)
val messageExtras = MessageExtras(pollTerminate = PollTerminate(question = updateMessage.pollTerminate.question, messageId = pollMessageId, targetTimestamp = updateMessage.pollTerminate.targetSentTimestamp))
db.update(MessageTable.TABLE_NAME)
.values(MessageTable.MESSAGE_EXTRAS to messageExtras.encode())
.where("${MessageTable.ID} = ?", endPollMessageId)
.run()
if (pollId != null) {
SignalDatabase.polls.endPoll(pollId = pollId, endingMessageId = endPollMessageId)
}
}
}
}
@@ -459,6 +476,35 @@ class ChatItemArchiveImporter(
}
}
if (this.poll != null) {
contentValues.put(MessageTable.BODY, poll.question)
contentValues.put(MessageTable.VOTES_LAST_SEEN, System.currentTimeMillis())
followUps += { messageRowId ->
val pollId = SignalDatabase.polls.insertPoll(
question = poll.question,
allowMultipleVotes = poll.allowMultiple,
options = poll.options.map { it.option },
authorId = fromRecipientId.toLong(),
messageId = messageRowId
)
val localOptionIds = SignalDatabase.polls.getPollOptionIds(pollId)
poll.options.forEachIndexed { index, option ->
val localVoterIds = option.votes.map { importState.remoteToLocalRecipientId[it.voterId]?.toLong() }
val voteCounts = option.votes.map { it.voteCount }
val localVoters = localVoterIds.mapIndexedNotNull { index, id -> id?.let { Voter(id = id, voteCount = voteCounts[index]) } }
SignalDatabase.polls.addPollVotes(pollId = pollId, optionId = localOptionIds[index], voters = localVoters)
}
if (poll.hasEnded) {
// At this point, we don't know what message ended the poll. Instead, we set it to -1 to indicate that it
// is ended and will update endingMessageId when we process the poll terminate message (if it exists).
SignalDatabase.polls.endPoll(pollId = pollId, endingMessageId = -1)
}
}
}
val followUp: ((Long) -> Unit)? = if (followUps.isNotEmpty()) {
{ messageId ->
followUps.forEach { it(messageId) }
@@ -774,6 +820,9 @@ class ChatItemArchiveImporter(
val messageExtras = MessageExtras(profileChangeDetails = profileChangeDetails).encode()
put(MessageTable.MESSAGE_EXTRAS, messageExtras)
}
updateMessage.pollTerminate != null -> {
typeFlags = MessageTypes.SPECIAL_TYPE_POLL_TERMINATE or (getAsLong(MessageTable.TYPE) and MessageTypes.BASE_TYPE_MASK.inv())
}
updateMessage.sessionSwitchover != null -> {
typeFlags = MessageTypes.SESSION_SWITCHOVER_TYPE or (getAsLong(MessageTable.TYPE) and MessageTypes.BASE_TYPE_MASK.inv())
val sessionSwitchoverDetails = SessionSwitchoverEvent(e164 = updateMessage.sessionSwitchover.e164.toString()).encode()