From 525175f04a9d1b247769ad5e563ef8c92f8a84a9 Mon Sep 17 00:00:00 2001 From: Michelle Tang Date: Tue, 14 Oct 2025 11:41:47 -0400 Subject: [PATCH] Add polls to backups. --- .../securesms/database/PollTablesTest.kt | 5 +- .../DataMessageProcessorTest_polls.kt | 9 +- .../securesms/backup/v2/ArchiveErrorCases.kt | 4 + .../v2/exporters/ChatItemArchiveExporter.kt | 65 +++++++++++++- .../v2/importer/ChatItemArchiveImporter.kt | 49 +++++++++++ .../securesms/conversation/PollComponent.kt | 17 ++-- .../clicklisteners/PollVotesFragment.kt | 5 +- .../clicklisteners/PollVotesViewModel.kt | 2 +- .../securesms/database/PollTables.kt | 84 ++++++++++++------- .../securesms/database/ThreadBodyUtil.java | 2 +- .../database/model/MessageRecord.java | 2 +- .../securesms/polls/PollOption.kt | 2 +- .../org/thoughtcrime/securesms/polls/Voter.kt | 18 ++++ app/src/main/protowire/Backup.proto | 26 ++++++ 14 files changed, 236 insertions(+), 54 deletions(-) create mode 100644 app/src/main/java/org/thoughtcrime/securesms/polls/Voter.kt diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/database/PollTablesTest.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/database/PollTablesTest.kt index 85f7d2b870..dea0771dce 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/database/PollTablesTest.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/database/PollTablesTest.kt @@ -11,6 +11,7 @@ import org.thoughtcrime.securesms.database.model.MessageId import org.thoughtcrime.securesms.mms.IncomingMessage import org.thoughtcrime.securesms.polls.PollOption import org.thoughtcrime.securesms.polls.PollRecord +import org.thoughtcrime.securesms.polls.Voter import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.testing.SignalActivityRule @@ -28,7 +29,7 @@ class PollTablesTest { id = 1, question = "how do you feel about unit testing?", pollOptions = listOf( - PollOption(1, "yay", listOf(1)), + PollOption(1, "yay", listOf(Voter(1, 1))), PollOption(2, "ok", emptyList()), PollOption(3, "nay", emptyList()) ), @@ -79,7 +80,7 @@ class PollTablesTest { SignalDatabase.polls.insertVotes(pollId = 1, pollOptionIds = listOf(3), voterId = 1, voteCount = 2, messageId = MessageId(1)) SignalDatabase.polls.insertVotes(pollId = 1, pollOptionIds = listOf(1), voterId = 1, voteCount = 3, messageId = MessageId(1)) - assertEquals(poll1, SignalDatabase.polls.getPoll(1)) + assertEquals(listOf(Voter(1, 3)), SignalDatabase.polls.getPoll(1)!!.pollOptions[0].voters) } @Test diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/messages/DataMessageProcessorTest_polls.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/messages/DataMessageProcessorTest_polls.kt index ae1e68cd97..13d8adb0a3 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/messages/DataMessageProcessorTest_polls.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/messages/DataMessageProcessorTest_polls.kt @@ -16,6 +16,7 @@ import org.thoughtcrime.securesms.database.SignalDatabase import org.thoughtcrime.securesms.database.model.MessageId import org.thoughtcrime.securesms.groups.GroupId import org.thoughtcrime.securesms.mms.IncomingMessage +import org.thoughtcrime.securesms.polls.Voter import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.RecipientId import org.thoughtcrime.securesms.testing.GroupTestingUtils @@ -187,7 +188,7 @@ class DataMessageProcessorTest_polls { assertThat(messageId!!.id).isEqualTo(1) val poll = SignalDatabase.polls.getPoll(messageId.id) assert(poll != null) - assertThat(poll!!.pollOptions[0].voterIds).isEqualTo(listOf(bob.id.toLong())) + assertThat(poll!!.pollOptions[0].voters).isEqualTo(listOf(Voter(bob.id.toLong(), 1))) } @Test @@ -207,9 +208,9 @@ class DataMessageProcessorTest_polls { assert(messageId != null) val poll = SignalDatabase.polls.getPoll(messageId!!.id) assert(poll != null) - assertThat(poll!!.pollOptions[0].voterIds).isEqualTo(listOf(bob.id.toLong())) - assertThat(poll.pollOptions[1].voterIds).isEqualTo(listOf(bob.id.toLong())) - assertThat(poll.pollOptions[2].voterIds).isEqualTo(listOf(bob.id.toLong())) + assertThat(poll!!.pollOptions[0].voters).isEqualTo(listOf(Voter(bob.id.toLong(), 1))) + assertThat(poll.pollOptions[1].voters).isEqualTo(listOf(Voter(bob.id.toLong(), 1))) + assertThat(poll.pollOptions[2].voters).isEqualTo(listOf(Voter(bob.id.toLong(), 1))) } @Test diff --git a/app/src/main/java/org/thoughtcrime/securesms/backup/v2/ArchiveErrorCases.kt b/app/src/main/java/org/thoughtcrime/securesms/backup/v2/ArchiveErrorCases.kt index 1beb751e38..2fec5f27e4 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/backup/v2/ArchiveErrorCases.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/backup/v2/ArchiveErrorCases.kt @@ -119,6 +119,10 @@ object ExportSkips { return log(sentTimestamp, "Failed to parse thread merge event.") } + fun pollTerminateIsEmpty(sentTimestamp: Long): String { + return log(sentTimestamp, "Poll terminate update was empty.") + } + fun individualChatUpdateInWrongTypeOfChat(sentTimestamp: Long): String { return log(sentTimestamp, "A chat update that only makes sense for individual chats was found in a different kind of chat.") } diff --git a/app/src/main/java/org/thoughtcrime/securesms/backup/v2/exporters/ChatItemArchiveExporter.kt b/app/src/main/java/org/thoughtcrime/securesms/backup/v2/exporters/ChatItemArchiveExporter.kt index 24322329f3..2dd4940710 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/backup/v2/exporters/ChatItemArchiveExporter.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/backup/v2/exporters/ChatItemArchiveExporter.kt @@ -50,6 +50,8 @@ import org.thoughtcrime.securesms.backup.v2.proto.IndividualCall import org.thoughtcrime.securesms.backup.v2.proto.LearnedProfileChatUpdate import org.thoughtcrime.securesms.backup.v2.proto.MessageAttachment import org.thoughtcrime.securesms.backup.v2.proto.PaymentNotification +import org.thoughtcrime.securesms.backup.v2.proto.Poll +import org.thoughtcrime.securesms.backup.v2.proto.PollTerminateUpdate import org.thoughtcrime.securesms.backup.v2.proto.ProfileChangeChatUpdate import org.thoughtcrime.securesms.backup.v2.proto.Quote import org.thoughtcrime.securesms.backup.v2.proto.Reaction @@ -93,6 +95,7 @@ import org.thoughtcrime.securesms.mms.PartAuthority import org.thoughtcrime.securesms.mms.QuoteModel import org.thoughtcrime.securesms.payments.FailureReason import org.thoughtcrime.securesms.payments.State +import org.thoughtcrime.securesms.polls.PollRecord import org.thoughtcrime.securesms.recipients.RecipientId import org.thoughtcrime.securesms.util.JsonUtils import org.thoughtcrime.securesms.util.MediaUtil @@ -371,6 +374,22 @@ class ChatItemArchiveExporter( transformTimer.emit("story") } + MessageTypes.isPollTerminate(record.type) -> { + val pollTerminateUpdate = record.toRemotePollTerminateUpdate() + if (pollTerminateUpdate == null) { + Log.w(TAG, ExportSkips.pollTerminateIsEmpty(record.dateSent)) + continue + } + builder.updateMessage = ChatUpdateMessage(pollTerminate = pollTerminateUpdate) + transformTimer.emit("poll-terminate") + } + + extraData.pollsById[record.id] != null -> { + val poll = extraData.pollsById[record.id]!! + builder.poll = poll.toRemotePollMessage() + transformTimer.emit("poll") + } + else -> { val attachments = extraData.attachmentsById[record.id] val sticker = attachments?.firstOrNull { dbAttachment -> dbAttachment.isSticker } @@ -471,16 +490,24 @@ class ChatItemArchiveExporter( } } + val pollsFuture = executor.submitTyped { + extraDataTimer.timeEvent("polls") { + db.pollTable.getPollsForMessages(messageIds) + } + } + val mentionsResult = mentionsFuture.get() val reactionsResult = reactionsFuture.get() val attachmentsResult = attachmentsFuture.get() val groupReceiptsResult = groupReceiptsFuture.get() + val pollsResult = pollsFuture.get() return ExtraMessageData( mentionsById = mentionsResult, reactionsById = reactionsResult, attachmentsById = attachmentsResult, - groupReceiptsById = groupReceiptsResult + groupReceiptsById = groupReceiptsResult, + pollsById = pollsResult ) } } @@ -783,6 +810,14 @@ private fun BackupMessageRecord.toRemotePaymentNotificationUpdate(db: SignalData } } +private fun BackupMessageRecord.toRemotePollTerminateUpdate(): PollTerminateUpdate? { + val pollTerminate = this.messageExtras?.pollTerminate ?: return null + return PollTerminateUpdate( + targetSentTimestamp = pollTerminate.targetTimestamp, + question = pollTerminate.question + ) +} + private fun BackupMessageRecord.toRemoteSharedContact(attachments: List?): Contact? { if (this.sharedContacts.isNullOrEmpty()) { return null @@ -1131,6 +1166,25 @@ private fun BackupMessageRecord.toRemoteGiftBadgeUpdate(): BackupGiftBadge? { ) } +private fun PollRecord.toRemotePollMessage(): Poll { + return Poll( + question = this.question, + allowMultiple = this.allowMultipleVotes, + hasEnded = this.hasEnded, + options = this.pollOptions.map { option -> + Poll.PollOption( + option = option.text, + votes = option.voters.map { voter -> + Poll.PollOption.PollVote( + voterId = voter.id, + voteCount = voter.voteCount + ) + } + ) + } + ) +} + private fun DatabaseAttachment.toRemoteStickerMessage(sentTimestamp: Long, reactions: List?): StickerMessage? { val stickerLocator = this.stickerLocator!! @@ -1491,7 +1545,8 @@ private fun Long.isDirectionlessType(): Boolean { MessageTypes.isGroupCall(this) || MessageTypes.isGroupUpdate(this) || MessageTypes.isGroupV1MigrationEvent(this) || - MessageTypes.isGroupQuit(this) + MessageTypes.isGroupQuit(this) || + MessageTypes.isPollTerminate(this) } private fun Long.isIdentityVerifyType(): Boolean { @@ -1522,7 +1577,8 @@ private fun ChatItem.validateChatItem(exportState: ExportState): ChatItem? { this.paymentNotification == null && this.giftBadge == null && this.viewOnceMessage == null && - this.directStoryReplyMessage == null + this.directStoryReplyMessage == null && + this.poll == null ) { Log.w(TAG, ExportSkips.emptyChatItem(this.dateSent)) return null @@ -1693,7 +1749,8 @@ private data class ExtraMessageData( val mentionsById: Map>, val reactionsById: Map>, val attachmentsById: Map>, - val groupReceiptsById: Map> + val groupReceiptsById: Map>, + val pollsById: Map ) private enum class Direction { diff --git a/app/src/main/java/org/thoughtcrime/securesms/backup/v2/importer/ChatItemArchiveImporter.kt b/app/src/main/java/org/thoughtcrime/securesms/backup/v2/importer/ChatItemArchiveImporter.kt index db3b31a351..9cd2c097d0 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/backup/v2/importer/ChatItemArchiveImporter.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/backup/v2/importer/ChatItemArchiveImporter.kt @@ -62,6 +62,7 @@ import org.thoughtcrime.securesms.database.model.databaseprotos.GV2UpdateDescrip import org.thoughtcrime.securesms.database.model.databaseprotos.GiftBadge import org.thoughtcrime.securesms.database.model.databaseprotos.MessageExtras import org.thoughtcrime.securesms.database.model.databaseprotos.PaymentTombstone +import org.thoughtcrime.securesms.database.model.databaseprotos.PollTerminate import org.thoughtcrime.securesms.database.model.databaseprotos.ProfileChangeDetails import org.thoughtcrime.securesms.database.model.databaseprotos.SessionSwitchoverEvent import org.thoughtcrime.securesms.database.model.databaseprotos.ThreadMergeEvent @@ -72,6 +73,7 @@ import org.thoughtcrime.securesms.payments.Direction import org.thoughtcrime.securesms.payments.FailureReason import org.thoughtcrime.securesms.payments.State import org.thoughtcrime.securesms.payments.proto.PaymentMetaData +import org.thoughtcrime.securesms.polls.Voter import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.RecipientId import org.thoughtcrime.securesms.stickers.StickerLocator @@ -304,6 +306,21 @@ class ChatItemArchiveImporter( ) db.insert(CallTable.TABLE_NAME, SQLiteDatabase.CONFLICT_IGNORE, values) } + } else if (this.updateMessage.pollTerminate != null) { + followUps += { endPollMessageId -> + val pollMessageId = SignalDatabase.messages.getMessageFor(updateMessage.pollTerminate.targetSentTimestamp, fromRecipientId)?.id ?: -1 + val pollId = SignalDatabase.polls.getPollId(pollMessageId) + + val messageExtras = MessageExtras(pollTerminate = PollTerminate(question = updateMessage.pollTerminate.question, messageId = pollMessageId, targetTimestamp = updateMessage.pollTerminate.targetSentTimestamp)) + db.update(MessageTable.TABLE_NAME) + .values(MessageTable.MESSAGE_EXTRAS to messageExtras.encode()) + .where("${MessageTable.ID} = ?", endPollMessageId) + .run() + + if (pollId != null) { + SignalDatabase.polls.endPoll(pollId = pollId, endingMessageId = endPollMessageId) + } + } } } @@ -459,6 +476,35 @@ class ChatItemArchiveImporter( } } + if (this.poll != null) { + contentValues.put(MessageTable.BODY, poll.question) + contentValues.put(MessageTable.VOTES_LAST_SEEN, System.currentTimeMillis()) + + followUps += { messageRowId -> + val pollId = SignalDatabase.polls.insertPoll( + question = poll.question, + allowMultipleVotes = poll.allowMultiple, + options = poll.options.map { it.option }, + authorId = fromRecipientId.toLong(), + messageId = messageRowId + ) + + val localOptionIds = SignalDatabase.polls.getPollOptionIds(pollId) + poll.options.forEachIndexed { index, option -> + val localVoterIds = option.votes.map { importState.remoteToLocalRecipientId[it.voterId]?.toLong() } + val voteCounts = option.votes.map { it.voteCount } + val localVoters = localVoterIds.mapIndexedNotNull { index, id -> id?.let { Voter(id = id, voteCount = voteCounts[index]) } } + SignalDatabase.polls.addPollVotes(pollId = pollId, optionId = localOptionIds[index], voters = localVoters) + } + + if (poll.hasEnded) { + // At this point, we don't know what message ended the poll. Instead, we set it to -1 to indicate that it + // is ended and will update endingMessageId when we process the poll terminate message (if it exists). + SignalDatabase.polls.endPoll(pollId = pollId, endingMessageId = -1) + } + } + } + val followUp: ((Long) -> Unit)? = if (followUps.isNotEmpty()) { { messageId -> followUps.forEach { it(messageId) } @@ -774,6 +820,9 @@ class ChatItemArchiveImporter( val messageExtras = MessageExtras(profileChangeDetails = profileChangeDetails).encode() put(MessageTable.MESSAGE_EXTRAS, messageExtras) } + updateMessage.pollTerminate != null -> { + typeFlags = MessageTypes.SPECIAL_TYPE_POLL_TERMINATE or (getAsLong(MessageTable.TYPE) and MessageTypes.BASE_TYPE_MASK.inv()) + } updateMessage.sessionSwitchover != null -> { typeFlags = MessageTypes.SESSION_SWITCHOVER_TYPE or (getAsLong(MessageTable.TYPE) and MessageTypes.BASE_TYPE_MASK.inv()) val sessionSwitchoverDetails = SessionSwitchoverEvent(e164 = updateMessage.sessionSwitchover.e164.toString()).encode() diff --git a/app/src/main/java/org/thoughtcrime/securesms/conversation/PollComponent.kt b/app/src/main/java/org/thoughtcrime/securesms/conversation/PollComponent.kt index ddb0631773..f678dc65a3 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/conversation/PollComponent.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/conversation/PollComponent.kt @@ -50,6 +50,7 @@ import org.thoughtcrime.securesms.R import org.thoughtcrime.securesms.components.compose.RoundCheckbox import org.thoughtcrime.securesms.polls.PollOption import org.thoughtcrime.securesms.polls.PollRecord +import org.thoughtcrime.securesms.polls.Voter import org.thoughtcrime.securesms.util.DynamicTheme import org.thoughtcrime.securesms.util.VibrateUtil @@ -85,7 +86,7 @@ private fun Poll( onToggleVote: (PollOption, Boolean) -> Unit = { _, _ -> }, pollColors: PollColors = PollColorsType.Incoming.getColors(-1) ) { - val totalVotes = remember(poll.pollOptions) { poll.pollOptions.sumOf { it.voterIds.size } } + val totalVotes = remember(poll.pollOptions) { poll.pollOptions.sumOf { it.voters.size } } val caption = when { poll.hasEnded -> R.string.Poll__final_results poll.allowMultipleVotes -> R.string.Poll__select_multiple @@ -139,8 +140,8 @@ private fun PollOption( ) { val context = LocalContext.current val haptics = LocalHapticFeedback.current - val progress = remember(option.voterIds.size, totalVotes) { - if (totalVotes > 0) (option.voterIds.size.toFloat() / totalVotes.toFloat()) else 0f + val progress = remember(option.voters.size, totalVotes) { + if (totalVotes > 0) (option.voters.size.toFloat() / totalVotes.toFloat()) else 0f } val progressValue by animateFloatAsState(targetValue = progress, animationSpec = tween(durationMillis = 250)) @@ -201,7 +202,7 @@ private fun PollOption( } AnimatedContent( - targetState = option.voterIds.size + targetState = option.voters.size ) { size -> Text( text = size.toString(), @@ -289,9 +290,9 @@ private fun PollPreview() { id = 1, question = "How do you feel about compose previews?", pollOptions = listOf( - PollOption(1, "yay", listOf(1), isSelected = true), - PollOption(2, "ok", listOf(1, 2)), - PollOption(3, "nay", listOf(2, 3, 4)) + PollOption(1, "yay", listOf(Voter(1, 1)), isSelected = true), + PollOption(2, "ok", listOf(Voter(1, 1), Voter(2, 1))), + PollOption(3, "nay", listOf(Voter(1, 1), Voter(2, 1), Voter(3, 1))) ), allowMultipleVotes = false, hasEnded = false, @@ -333,7 +334,7 @@ private fun FinishedPollPreview() { id = 1, question = "How do you feel about finished compose previews?", pollOptions = listOf( - PollOption(1, "yay", listOf(1)), + PollOption(1, "yay", listOf(Voter(1, 1))), PollOption(2, "ok", emptyList(), isSelected = true), PollOption(3, "nay", emptyList()) ), diff --git a/app/src/main/java/org/thoughtcrime/securesms/conversation/clicklisteners/PollVotesFragment.kt b/app/src/main/java/org/thoughtcrime/securesms/conversation/clicklisteners/PollVotesFragment.kt index a72eb00a62..c32562fd82 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/conversation/clicklisteners/PollVotesFragment.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/conversation/clicklisteners/PollVotesFragment.kt @@ -48,6 +48,7 @@ import org.thoughtcrime.securesms.compose.ComposeDialogFragment import org.thoughtcrime.securesms.conversation.clicklisteners.PollVotesFragment.Companion.MAX_INITIAL_VOTER_COUNT import org.thoughtcrime.securesms.polls.PollOption import org.thoughtcrime.securesms.polls.PollRecord +import org.thoughtcrime.securesms.polls.Voter import org.thoughtcrime.securesms.util.viewModel /** @@ -218,8 +219,8 @@ private fun PollResultsScreenPreview() { id = 1, question = "How do you feel about finished compose previews?", pollOptions = listOf( - PollOption(1, "Yay", listOf(1, 12, 3)), - PollOption(2, "Ok", listOf(2, 4), isSelected = true), + PollOption(1, "Yay", listOf(Voter(1, 1), Voter(12, 1), Voter(3, 1))), + PollOption(2, "Ok", listOf(Voter(2, 1), Voter(4, 1)), isSelected = true), PollOption(3, "Nay", emptyList()) ), allowMultipleVotes = false, diff --git a/app/src/main/java/org/thoughtcrime/securesms/conversation/clicklisteners/PollVotesViewModel.kt b/app/src/main/java/org/thoughtcrime/securesms/conversation/clicklisteners/PollVotesViewModel.kt index d2517c94f6..9853590d65 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/conversation/clicklisteners/PollVotesViewModel.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/conversation/clicklisteners/PollVotesViewModel.kt @@ -39,7 +39,7 @@ class PollVotesViewModel(pollId: Long) : ViewModel() { pollOptions = poll.pollOptions.map { option -> PollOptionModel( pollOption = option, - voters = Recipient.resolvedList(option.voterIds.map { voter -> RecipientId.from(voter) }) + voters = Recipient.resolvedList(option.voters.map { voter -> RecipientId.from(voter.id) }) ) }, isAuthor = poll.authorId == Recipient.self().id.toLong() diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/PollTables.kt b/app/src/main/java/org/thoughtcrime/securesms/database/PollTables.kt index a15cf3defb..b7c83fd442 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/PollTables.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/PollTables.kt @@ -30,6 +30,7 @@ import org.thoughtcrime.securesms.polls.Poll import org.thoughtcrime.securesms.polls.PollOption import org.thoughtcrime.securesms.polls.PollRecord import org.thoughtcrime.securesms.polls.PollVote +import org.thoughtcrime.securesms.polls.Voter import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.RecipientId @@ -171,10 +172,10 @@ class PollTables(context: Context?, databaseHelper: SignalDatabase?) : DatabaseT } /** - * Inserts a newly created poll with its options + * Inserts a newly created poll with its options. Returns the newly created row id */ - fun insertPoll(question: String, allowMultipleVotes: Boolean, options: List, authorId: Long, messageId: Long) { - writableDatabase.withinTransaction { db -> + fun insertPoll(question: String, allowMultipleVotes: Boolean, options: List, authorId: Long, messageId: Long): Long { + return writableDatabase.withinTransaction { db -> val pollId = db.insertInto(PollTable.TABLE_NAME) .values( contentValuesOf( @@ -193,6 +194,30 @@ class PollTables(context: Context?, databaseHelper: SignalDatabase?) : DatabaseT ).forEach { db.execSQL(it.where, it.whereArgs) } + pollId + } + } + + /** + * Inserts a poll option and voters for that option. Called when restoring polls from backups. + */ + fun addPollVotes(pollId: Long, optionId: Long, voters: List) { + writableDatabase.withinTransaction { db -> + SqlUtil.buildBulkInsert( + PollVoteTable.TABLE_NAME, + arrayOf(PollVoteTable.POLL_ID, PollVoteTable.POLL_OPTION_ID, PollVoteTable.VOTER_ID, PollVoteTable.VOTE_COUNT, PollVoteTable.DATE_RECEIVED, PollVoteTable.VOTE_STATE), + voters.map { voter -> + contentValuesOf( + PollVoteTable.POLL_ID to pollId, + PollVoteTable.POLL_OPTION_ID to optionId, + PollVoteTable.VOTER_ID to voter.id, + PollVoteTable.VOTE_COUNT to voter.voteCount, + PollVoteTable.VOTE_STATE to VoteState.ADDED.value + ) + } + ).forEach { + db.execSQL(it.where, it.whereArgs) + } } } @@ -504,31 +529,30 @@ class PollTables(context: Context?, databaseHelper: SignalDatabase?) : DatabaseT val self = Recipient.self().id.toLong() val query = SqlUtil.buildFastCollectionQuery(PollTable.MESSAGE_ID, messageIds) - return readableDatabase.withinTransaction { db -> - db.select(PollTable.ID, PollTable.MESSAGE_ID, PollTable.QUESTION, PollTable.ALLOW_MULTIPLE_VOTES, PollTable.END_MESSAGE_ID, PollTable.AUTHOR_ID, PollTable.MESSAGE_ID) - .from(PollTable.TABLE_NAME) - .where(query.where, query.whereArgs) - .run() - .readToMap { cursor -> - val pollId = cursor.requireLong(PollTable.ID) - val pollVotes = getPollVotes(pollId) - val pendingVotes = getPendingVotes(pollId) - val pollOptions = getPollOptions(pollId).map { option -> - val voterIds = pollVotes[option.key] ?: emptyList() - PollOption(id = option.key, text = option.value, voterIds = voterIds, isSelected = voterIds.contains(self), isPending = pendingVotes.contains(option.key)) - } - val poll = PollRecord( - id = pollId, - question = cursor.requireNonNullString(PollTable.QUESTION), - pollOptions = pollOptions, - allowMultipleVotes = cursor.requireBoolean(PollTable.ALLOW_MULTIPLE_VOTES), - hasEnded = cursor.requireBoolean(PollTable.END_MESSAGE_ID), - authorId = cursor.requireLong(PollTable.AUTHOR_ID), - messageId = cursor.requireLong(PollTable.MESSAGE_ID) - ) - cursor.requireLong(PollTable.MESSAGE_ID) to poll + return readableDatabase + .select(PollTable.ID, PollTable.MESSAGE_ID, PollTable.QUESTION, PollTable.ALLOW_MULTIPLE_VOTES, PollTable.END_MESSAGE_ID, PollTable.AUTHOR_ID, PollTable.MESSAGE_ID) + .from(PollTable.TABLE_NAME) + .where(query.where, query.whereArgs) + .run() + .readToMap { cursor -> + val pollId = cursor.requireLong(PollTable.ID) + val pollVotes = getPollVotes(pollId) + val pendingVotes = getPendingVotes(pollId) + val pollOptions = getPollOptions(pollId).map { option -> + val voters = pollVotes[option.key] ?: emptyList() + PollOption(id = option.key, text = option.value, voters = voters, isSelected = voters.any { it.id == self }, isPending = pendingVotes.contains(option.key)) } - } + val poll = PollRecord( + id = pollId, + question = cursor.requireNonNullString(PollTable.QUESTION), + pollOptions = pollOptions, + allowMultipleVotes = cursor.requireBoolean(PollTable.ALLOW_MULTIPLE_VOTES), + hasEnded = cursor.requireBoolean(PollTable.END_MESSAGE_ID), + authorId = cursor.requireLong(PollTable.AUTHOR_ID), + messageId = cursor.requireLong(PollTable.MESSAGE_ID) + ) + cursor.requireLong(PollTable.MESSAGE_ID) to poll + } } /** @@ -593,14 +617,14 @@ class PollTables(context: Context?, databaseHelper: SignalDatabase?) : DatabaseT } } - private fun getPollVotes(pollId: Long): Map> { + private fun getPollVotes(pollId: Long): Map> { return readableDatabase - .select(PollVoteTable.POLL_OPTION_ID, PollVoteTable.VOTER_ID) + .select(PollVoteTable.POLL_OPTION_ID, PollVoteTable.VOTER_ID, PollVoteTable.VOTE_COUNT) .from(PollVoteTable.TABLE_NAME) .where("${PollVoteTable.POLL_ID} = ? AND (${PollVoteTable.VOTE_STATE} = ${VoteState.ADDED.value} OR ${PollVoteTable.VOTE_STATE} = ${VoteState.PENDING_REMOVE.value})", pollId) .run() .groupBy { cursor -> - cursor.requireLong(PollVoteTable.POLL_OPTION_ID) to cursor.requireLong(PollVoteTable.VOTER_ID) + cursor.requireLong(PollVoteTable.POLL_OPTION_ID) to Voter(id = cursor.requireLong(PollVoteTable.VOTER_ID), voteCount = cursor.requireInt(PollVoteTable.VOTE_COUNT)) } } diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/ThreadBodyUtil.java b/app/src/main/java/org/thoughtcrime/securesms/database/ThreadBodyUtil.java index 7401f9e334..0492a45a14 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/ThreadBodyUtil.java +++ b/app/src/main/java/org/thoughtcrime/securesms/database/ThreadBodyUtil.java @@ -74,7 +74,7 @@ public final class ThreadBodyUtil { } else if (MessageRecordUtil.hasPoll(record)) { return new ThreadBody(context.getString(R.string.Poll__poll_question, record.getPoll().getQuestion())); } else if (MessageRecordUtil.hasPollTerminate(record)) { - String creator = record.isOutgoing() ? context.getResources().getString(R.string.MessageRecord_you) : record.getFromRecipient().getDisplayName(context); + String creator = record.getFromRecipient().isSelf() ? context.getResources().getString(R.string.MessageRecord_you) : record.getFromRecipient().getDisplayName(context); return new ThreadBody(context.getString(R.string.Poll__poll_end, creator, record.getMessageExtras().pollTerminate.question)); } diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/model/MessageRecord.java b/app/src/main/java/org/thoughtcrime/securesms/database/model/MessageRecord.java index dfe5f2a7d3..306ce4cbe0 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/model/MessageRecord.java +++ b/app/src/main/java/org/thoughtcrime/securesms/database/model/MessageRecord.java @@ -295,7 +295,7 @@ public abstract class MessageRecord extends DisplayRecord { } else if (isUnsupported()) { return staticUpdateDescription(context.getString(R.string.MessageRecord_unsupported_feature, getFromRecipient().getDisplayName(context)), Glyph.ERROR); } else if (MessageRecordUtil.hasPollTerminate(this)) { - String creator = isOutgoing() ? context.getString(R.string.MessageRecord_you) : getFromRecipient().getDisplayName(context); + String creator = getFromRecipient().isSelf() ? context.getString(R.string.MessageRecord_you) : getFromRecipient().getDisplayName(context); return staticUpdateDescriptionWithExpiration(context.getString(R.string.MessageRecord_ended_the_poll, creator, messageExtras.pollTerminate.question), Glyph.POLL); } diff --git a/app/src/main/java/org/thoughtcrime/securesms/polls/PollOption.kt b/app/src/main/java/org/thoughtcrime/securesms/polls/PollOption.kt index 880ca25549..a238d74ad2 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/polls/PollOption.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/polls/PollOption.kt @@ -10,7 +10,7 @@ import kotlinx.parcelize.Parcelize data class PollOption( val id: Long, val text: String, - val voterIds: List, + val voters: List, val isSelected: Boolean = false, val isPending: Boolean = false ) : Parcelable diff --git a/app/src/main/java/org/thoughtcrime/securesms/polls/Voter.kt b/app/src/main/java/org/thoughtcrime/securesms/polls/Voter.kt new file mode 100644 index 0000000000..8cb27dba73 --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/polls/Voter.kt @@ -0,0 +1,18 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.thoughtcrime.securesms.polls + +import android.os.Parcelable +import kotlinx.parcelize.Parcelize + +/** + * Class to track someone who has voted in an option within a poll. + */ +@Parcelize +data class Voter( + val id: Long, + val voteCount: Int +) : Parcelable diff --git a/app/src/main/protowire/Backup.proto b/app/src/main/protowire/Backup.proto index 752d3be5ab..1c88c6a5f5 100644 --- a/app/src/main/protowire/Backup.proto +++ b/app/src/main/protowire/Backup.proto @@ -428,6 +428,7 @@ message ChatItem { GiftBadge giftBadge = 17; ViewOnceMessage viewOnceMessage = 18; DirectStoryReplyMessage directStoryReplyMessage = 19; // group story reply messages are not backed up + Poll poll = 20; } } @@ -805,6 +806,25 @@ message Reaction { uint64 sortOrder = 4; } +message Poll { + + message PollOption { + + message PollVote { + uint64 voterId = 1; // A direct reference to Recipient proto id. Must be self or contact. + uint32 voteCount = 2; // Tracks how many times you voted. + } + + string option = 1; // Between 1-100 characters + repeated PollVote votes = 2; + } + + string question = 1; // Between 1-100 characters + bool allowMultiple = 2; + repeated PollOption options = 3; // At least two + bool hasEnded = 4; +} + message ChatUpdateMessage { // If unset, importers should ignore the update message without throwing an error. oneof update { @@ -817,6 +837,7 @@ message ChatUpdateMessage { IndividualCall individualCall = 7; GroupCall groupCall = 8; LearnedProfileChatUpdate learnedProfileChange = 9; + PollTerminateUpdate pollTerminate = 10; } } @@ -1182,6 +1203,11 @@ message GroupExpirationTimerUpdate { optional bytes updaterAci = 2; } +message PollTerminateUpdate { + uint64 targetSentTimestamp = 1; + string question = 2; // Between 1-100 characters +} + message StickerPack { bytes packId = 1; bytes packKey = 2;