diff --git a/ts/sql/Interface.std.ts b/ts/sql/Interface.std.ts index ce22cc7f07..02db9cf751 100644 --- a/ts/sql/Interface.std.ts +++ b/ts/sql/Interface.std.ts @@ -869,7 +869,8 @@ type ReadableInterface = { getMessageByAuthorAciAndSentAt: ( authorAci: AciString, - sentAtTimestamp: number + sentAtTimestamp: number, + options: { includeEdits: boolean } ) => MessageType | null; getMessageBySender: (options: { source?: string; diff --git a/ts/sql/Server.node.ts b/ts/sql/Server.node.ts index 83545f841b..0046f80032 100644 --- a/ts/sql/Server.node.ts +++ b/ts/sql/Server.node.ts @@ -49,7 +49,7 @@ import { isNormalNumber } from '../util/isNormalNumber.std.js'; import { isNotNil } from '../util/isNotNil.std.js'; import { parseIntOrThrow } from '../util/parseIntOrThrow.std.js'; import { updateSchema } from './migrations/index.node.js'; -import type { JSONRows, QueryFragment, QueryTemplate } from './util.std.js'; +import type { JSONRows, QueryTemplate, QueryFragment } from './util.std.js'; import { batchMultiVarQuery, bulkAdd, @@ -3322,17 +3322,30 @@ function getAllMessageIds(db: ReadableDB): Array { function getMessageByAuthorAciAndSentAt( db: ReadableDB, authorAci: AciString, - sentAtTimestamp: number + sentAtTimestamp: number, + options: { includeEdits: boolean } ): MessageType | null { return db.transaction(() => { - const [query, params] = sql` + const editedMessagesQuery = sqlFragment` + SELECT ${MESSAGE_COLUMNS_SELECT} + FROM edited_messages + INNER JOIN messages ON + messages.id = edited_messages.messageId + WHERE messages.sourceServiceId = ${authorAci} + AND edited_messages.sentAt = ${sentAtTimestamp} + `; + + const messagesQuery = sqlFragment` SELECT ${MESSAGE_COLUMNS_SELECT} FROM messages - WHERE sourceServiceId = ${authorAci} - AND sent_at = ${sentAtTimestamp} - LIMIT 2; + WHERE messages.sourceServiceId = ${authorAci} + AND messages.sent_at = ${sentAtTimestamp} `; + const [query, params] = options.includeEdits + ? sql`${editedMessagesQuery} UNION ${messagesQuery} LIMIT 2;` + : sql`${messagesQuery} LIMIT 2;`; + const rows = db.prepare(query).all(params); if (rows.length > 1) { diff --git a/ts/state/ducks/composer.preload.ts b/ts/state/ducks/composer.preload.ts index 20292c1e87..e75bb59acd 100644 --- a/ts/state/ducks/composer.preload.ts +++ b/ts/state/ducks/composer.preload.ts @@ -392,7 +392,8 @@ function scrollToPinnedMessage( return async (dispatch, getState) => { const pinnedMessage = await DataReader.getMessageByAuthorAciAndSentAt( pinMessage.targetAuthorAci, - pinMessage.targetSentTimestamp + pinMessage.targetSentTimestamp, + { includeEdits: true } ); if (!pinnedMessage) { diff --git a/ts/util/getPinMessageTarget.preload.ts b/ts/util/getPinMessageTarget.preload.ts index d8a3eb6e35..8e8151ef7b 100644 --- a/ts/util/getPinMessageTarget.preload.ts +++ b/ts/util/getPinMessageTarget.preload.ts @@ -1,14 +1,18 @@ // Copyright 2026 Signal Messenger, LLC // SPDX-License-Identifier: AGPL-3.0-only +import { createLogger } from '../logging/log.std.js'; import { isIncoming } from '../messages/helpers.std.js'; import type { ReadonlyMessageAttributesType } from '../model-types.js'; import { DataReader } from '../sql/Client.preload.js'; import { itemStorage } from '../textsecure/Storage.preload.js'; import type { AciString } from '../types/ServiceId.std.js'; import { strictAssert } from './assert.std.js'; +import { getMessageSentTimestamp } from './getMessageSentTimestamp.std.js'; import { isAciString } from './isAciString.std.js'; +const log = createLogger('getPinMessageTarget'); + export type PinnedMessageTarget = Readonly<{ conversationId: string; targetMessageId: string; @@ -40,6 +44,9 @@ export async function getPinnedMessageTarget( conversationId: message.conversationId, targetMessageId: message.id, targetAuthorAci: getMessageAuthorAci(message), - targetSentTimestamp: message.sent_at, + targetSentTimestamp: getMessageSentTimestamp(message, { + includeEdits: true, + log, + }), }; }