diff --git a/src/vs/platform/agentHost/common/agentService.ts b/src/vs/platform/agentHost/common/agentService.ts index baf51f3810f..c60161cf89c 100644 --- a/src/vs/platform/agentHost/common/agentService.ts +++ b/src/vs/platform/agentHost/common/agentService.ts @@ -101,6 +101,8 @@ export interface IAgentCreateSessionConfig { readonly model?: string; readonly session?: URI; readonly workingDirectory?: URI; + /** Fork from an existing session at a specific turn index. */ + readonly fork?: { readonly session: URI; readonly turnIndex: number }; } /** Serializable attachment passed alongside a message to the agent host. */ @@ -364,6 +366,23 @@ export interface IAgent { */ authenticate(resource: string, token: string): Promise; + /** + * Truncate a session's history. If `turnIndex` is provided (0-based), keeps + * turns up to and including that turn. If omitted, all turns are removed. + * Optional — not all providers support truncation. + */ + truncateSession?(session: URI, turnIndex?: number): Promise; + + /** + * Fork a session at a specific turn, creating a new session on disk + * with the source session's history up to and including the specified turn. + * Optional — not all providers support forking. + * + * @param turnIndex 0-based turn index to fork at. + * @returns The new session's raw ID. + */ + forkSession?(sourceSession: URI, newSessionId: string, turnIndex: number): Promise; + /** Gracefully shut down all sessions. */ shutdown(): Promise; diff --git a/src/vs/platform/agentHost/common/state/protocol/action-origin.generated.ts b/src/vs/platform/agentHost/common/state/protocol/action-origin.generated.ts index f5b81d88e23..ab09368f19f 100644 --- a/src/vs/platform/agentHost/common/state/protocol/action-origin.generated.ts +++ b/src/vs/platform/agentHost/common/state/protocol/action-origin.generated.ts @@ -9,7 +9,7 @@ // Generated from types/actions.ts — do not edit // Run `npm run generate` to regenerate. -import { ActionType, type IStateAction, type IRootAgentsChangedAction, type IRootActiveSessionsChangedAction, type ISessionReadyAction, type ISessionCreationFailedAction, type ISessionTurnStartedAction, type ISessionDeltaAction, type ISessionResponsePartAction, type ISessionToolCallStartAction, type ISessionToolCallDeltaAction, type ISessionToolCallReadyAction, type ISessionToolCallConfirmedAction, type ISessionToolCallCompleteAction, type ISessionToolCallResultConfirmedAction, type ISessionTurnCompleteAction, type ISessionTurnCancelledAction, type ISessionErrorAction, type ISessionTitleChangedAction, type ISessionUsageAction, type ISessionReasoningAction, type ISessionModelChangedAction, type ISessionServerToolsChangedAction, type ISessionActiveClientChangedAction, type ISessionActiveClientToolsChangedAction, type ISessionPendingMessageSetAction, type ISessionPendingMessageRemovedAction, type ISessionQueuedMessagesReorderedAction, type ISessionCustomizationsChangedAction, type ISessionCustomizationToggledAction } from './actions.js'; +import { ActionType, type IStateAction, type IRootAgentsChangedAction, type IRootActiveSessionsChangedAction, type ISessionReadyAction, type ISessionCreationFailedAction, type ISessionTurnStartedAction, type ISessionDeltaAction, type ISessionResponsePartAction, type ISessionToolCallStartAction, type ISessionToolCallDeltaAction, type ISessionToolCallReadyAction, type ISessionToolCallConfirmedAction, type ISessionToolCallCompleteAction, type ISessionToolCallResultConfirmedAction, type ISessionTurnCompleteAction, type ISessionTurnCancelledAction, type ISessionErrorAction, type ISessionTitleChangedAction, type ISessionUsageAction, type ISessionReasoningAction, type ISessionModelChangedAction, type ISessionServerToolsChangedAction, type ISessionActiveClientChangedAction, type ISessionActiveClientToolsChangedAction, type ISessionPendingMessageSetAction, type ISessionPendingMessageRemovedAction, type ISessionQueuedMessagesReorderedAction, type ISessionCustomizationsChangedAction, type ISessionCustomizationToggledAction, type ISessionTruncatedAction } from './actions.js'; // ─── Root vs Session Action Unions ─────────────────────────────────────────── @@ -48,6 +48,7 @@ export type ISessionAction = | ISessionQueuedMessagesReorderedAction | ISessionCustomizationsChangedAction | ISessionCustomizationToggledAction + | ISessionTruncatedAction ; /** Union of session actions that clients may dispatch. */ @@ -65,6 +66,7 @@ export type IClientSessionAction = | ISessionPendingMessageRemovedAction | ISessionQueuedMessagesReorderedAction | ISessionCustomizationToggledAction + | ISessionTruncatedAction ; /** Union of session actions that only the server may produce. */ @@ -119,4 +121,5 @@ export const IS_CLIENT_DISPATCHABLE: { readonly [K in IStateAction['type']]: boo [ActionType.SessionQueuedMessagesReordered]: true, [ActionType.SessionCustomizationsChanged]: false, [ActionType.SessionCustomizationToggled]: true, + [ActionType.SessionTruncated]: true, }; diff --git a/src/vs/platform/agentHost/common/state/protocol/actions.ts b/src/vs/platform/agentHost/common/state/protocol/actions.ts index 9c454210617..9a9b6cfb871 100644 --- a/src/vs/platform/agentHost/common/state/protocol/actions.ts +++ b/src/vs/platform/agentHost/common/state/protocol/actions.ts @@ -45,6 +45,7 @@ export const enum ActionType { SessionQueuedMessagesReordered = 'session/queuedMessagesReordered', SessionCustomizationsChanged = 'session/customizationsChanged', SessionCustomizationToggled = 'session/customizationToggled', + SessionTruncated = 'session/truncated', } // ─── Action Envelope ───────────────────────────────────────────────────────── @@ -562,6 +563,31 @@ export interface ISessionCustomizationToggledAction { enabled: boolean; } +// ─── Truncation ────────────────────────────────────────────────────────────── + +/** + * Truncates a session's history. If `turnId` is provided, all turns after that + * turn are removed and the specified turn is kept. If `turnId` is omitted, all + * turns are removed. + * + * If there is an active turn it is silently dropped and the session status + * returns to `idle`. + * + * Common use-case: truncate old data then dispatch a new + * `session/turnStarted` with an edited message. + * + * @category Session Actions + * @version 1 + * @clientDispatchable + */ +export interface ISessionTruncatedAction { + type: ActionType.SessionTruncated; + /** Session URI */ + session: URI; + /** Keep turns up to and including this turn. Omit to clear all turns. */ + turnId?: string; +} + // ─── Pending Message Actions ───────────────────────────────────────────────── /** @@ -664,4 +690,5 @@ export type IStateAction = | ISessionPendingMessageRemovedAction | ISessionQueuedMessagesReorderedAction | ISessionCustomizationsChangedAction - | ISessionCustomizationToggledAction; + | ISessionCustomizationToggledAction + | ISessionTruncatedAction; diff --git a/src/vs/platform/agentHost/common/state/protocol/commands.ts b/src/vs/platform/agentHost/common/state/protocol/commands.ts index 96ea242ae2e..889d4a883b8 100644 --- a/src/vs/platform/agentHost/common/state/protocol/commands.ts +++ b/src/vs/platform/agentHost/common/state/protocol/commands.ts @@ -163,6 +163,20 @@ export interface ISubscribeResult { * { "jsonrpc": "2.0", "id": 2, "error": { "code": -32003, "message": "Session already exists" } } * ``` */ +/** + * Identifies a source session and turn to fork from. + * + * When provided in `createSession`, the server populates the new session with + * content from the source session up to and including the response of the + * specified turn. + */ +export interface ISessionForkSource { + /** URI of the existing session to fork from */ + session: URI; + /** Turn ID in the source session; content up to and including this turn's response is copied */ + turnId: string; +} + export interface ICreateSessionParams { /** Session URI (client-chosen, e.g. `copilot:/`) */ session: URI; @@ -172,6 +186,11 @@ export interface ICreateSessionParams { model?: string; /** Working directory for the session */ workingDirectory?: URI; + /** + * Fork from an existing session. The new session is populated with content + * from the source session up to and including the specified turn's response. + */ + fork?: ISessionForkSource; } // ─── disposeSession ────────────────────────────────────────────────────────── diff --git a/src/vs/platform/agentHost/common/state/protocol/reducers.ts b/src/vs/platform/agentHost/common/state/protocol/reducers.ts index 4c830b93b95..128074285b2 100644 --- a/src/vs/platform/agentHost/common/state/protocol/reducers.ts +++ b/src/vs/platform/agentHost/common/state/protocol/reducers.ts @@ -485,6 +485,27 @@ export function sessionReducer(state: ISessionState, action: ISessionAction, log return { ...state, customizations: updated }; } + // ── Truncation ──────────────────────────────────────────────────────── + + case ActionType.SessionTruncated: { + let turns: typeof state.turns; + if (action.turnId === undefined) { + turns = []; + } else { + const idx = state.turns.findIndex(t => t.id === action.turnId); + if (idx < 0) { + return state; + } + turns = state.turns.slice(0, idx + 1); + } + return { + ...state, + turns, + activeTurn: undefined, + summary: { ...state.summary, status: SessionStatus.Idle, modifiedAt: Date.now() }, + }; + } + // ── Pending Messages ────────────────────────────────────────────────── case ActionType.SessionPendingMessageSet: { diff --git a/src/vs/platform/agentHost/common/state/protocol/version/registry.ts b/src/vs/platform/agentHost/common/state/protocol/version/registry.ts index c3d45ae3dfa..11d2bb4b017 100644 --- a/src/vs/platform/agentHost/common/state/protocol/version/registry.ts +++ b/src/vs/platform/agentHost/common/state/protocol/version/registry.ts @@ -52,6 +52,7 @@ export const ACTION_INTRODUCED_IN: { readonly [K in IStateAction['type']]: numbe [ActionType.SessionQueuedMessagesReordered]: 1, [ActionType.SessionCustomizationsChanged]: 1, [ActionType.SessionCustomizationToggled]: 1, + [ActionType.SessionTruncated]: 1, }; /** diff --git a/src/vs/platform/agentHost/node/agentService.ts b/src/vs/platform/agentHost/node/agentService.ts index a5c017323db..134e2f1f93e 100644 --- a/src/vs/platform/agentHost/node/agentService.ts +++ b/src/vs/platform/agentHost/node/agentService.ts @@ -173,17 +173,41 @@ export class AgentService extends Disposable implements IAgentService { this._sessionToProvider.set(session.toString(), provider.id); this._logService.trace(`[AgentService] createSession returned: ${session.toString()}`); - // Create state in the state manager - const summary: ISessionSummary = { - resource: session.toString(), - provider: provider.id, - title: 'New Session', - status: SessionStatus.Idle, - createdAt: Date.now(), - modifiedAt: Date.now(), - workingDirectory: config?.workingDirectory?.toString(), - }; - this._stateManager.createSession(summary); + // When forking, populate the new session's protocol state with + // the source session's turns so the client sees the forked history. + if (config?.fork) { + const sourceState = this._stateManager.getSessionState(config.fork.session.toString()); + let sourceTurns: ITurn[] = []; + if (sourceState) { + const forkIdx = sourceState.turns.findIndex(t => t.id === config.fork!.turnId); + if (forkIdx >= 0) { + sourceTurns = sourceState.turns.slice(0, forkIdx + 1); + } + } + + const summary: ISessionSummary = { + resource: session.toString(), + provider: provider.id, + title: sourceState?.summary.title ?? 'Forked Session', + status: SessionStatus.Idle, + createdAt: Date.now(), + modifiedAt: Date.now(), + workingDirectory: config.workingDirectory?.toString(), + }; + this._stateManager.restoreSession(summary, sourceTurns); + } else { + // Create empty state for new sessions + const summary: ISessionSummary = { + resource: session.toString(), + provider: provider.id, + title: 'New Session', + status: SessionStatus.Idle, + createdAt: Date.now(), + modifiedAt: Date.now(), + workingDirectory: config?.workingDirectory?.toString(), + }; + this._stateManager.createSession(summary); + } this._stateManager.dispatchServerAction({ type: ActionType.SessionReady, session: session.toString() }); return session; diff --git a/src/vs/platform/agentHost/node/agentSideEffects.ts b/src/vs/platform/agentHost/node/agentSideEffects.ts index ede6b69acd9..cd497855e99 100644 --- a/src/vs/platform/agentHost/node/agentSideEffects.ts +++ b/src/vs/platform/agentHost/node/agentSideEffects.ts @@ -252,6 +252,24 @@ export class AgentSideEffects extends Disposable { this._syncPendingMessages(action.session); break; } + case ActionType.SessionTruncated: { + const agent = this._options.getAgent(action.session); + // Resolve the protocol turnId to a 0-based index using the + // state manager's turn list. The reducer has already applied + // the truncation, so we look at the pre-truncation state via + // the turnId position. + let turnIndex: number | undefined; + if (action.turnId !== undefined) { + const state = this._stateManager.getSessionState(action.session); + // After the reducer, the turns array is already truncated. + // The kept turns include the target, so its index = length - 1. + turnIndex = state ? state.turns.length - 1 : undefined; + } + agent?.truncateSession?.(URI.parse(action.session), turnIndex).catch(err => { + this._logService.error('[AgentSideEffects] truncateSession failed', err); + }); + break; + } } } diff --git a/src/vs/platform/agentHost/node/copilot/copilotAgent.ts b/src/vs/platform/agentHost/node/copilot/copilotAgent.ts index bf4ea118f95..22bf7becbc1 100644 --- a/src/vs/platform/agentHost/node/copilot/copilotAgent.ts +++ b/src/vs/platform/agentHost/node/copilot/copilotAgent.ts @@ -5,6 +5,7 @@ import { CopilotClient } from '@github/copilot-sdk'; import { rgPath } from '@vscode/ripgrep'; +import { SequencerByKey } from '../../../../base/common/async.js'; import { Emitter } from '../../../../base/common/event.js'; import { Disposable, DisposableMap } from '../../../../base/common/lifecycle.js'; import { FileAccess } from '../../../../base/common/network.js'; @@ -18,6 +19,7 @@ import { AgentSession, IAgent, IAgentAttachment, IAgentCreateSessionConfig, IAge import { type IPendingMessage, type PolicyState } from '../../common/state/sessionState.js'; import { CopilotAgentSession, SessionWrapperFactory } from './copilotAgentSession.js'; import { CopilotSessionWrapper } from './copilotSessionWrapper.js'; +import { forkCopilotSessionOnDisk, getCopilotDataDir, truncateCopilotSessionOnDisk } from './copilotAgentForking.js'; /** * Agent provider backed by the Copilot SDK {@link CopilotClient}. @@ -32,6 +34,7 @@ export class CopilotAgent extends Disposable implements IAgent { private _clientStarting: Promise | undefined; private _githubToken: string | undefined; private readonly _sessions = this._register(new DisposableMap()); + private readonly _sessionSequencer = new SequencerByKey(); constructor( @ILogService private readonly _logService: ILogService, @@ -181,10 +184,38 @@ export class CopilotAgent extends Disposable implements IAgent { this._logService.info(`[Copilot] Creating session... ${config?.model ? `model=${config.model}` : ''}`); const client = await this._ensureClient(); + // When forking, we manipulate the CLI's on-disk data and then resume + // instead of creating a fresh session via the SDK. + if (config?.fork) { + const sourceSessionId = AgentSession.id(config.fork.session); + const newSessionId = config.session ? AgentSession.id(config.session) : generateUuid(); + + // Serialize against the source session to prevent concurrent + // modifications while we read its on-disk data. + return this._sessionSequencer.queue(sourceSessionId, async () => { + this._logService.info(`[Copilot] Forking session ${sourceSessionId} at index ${config.fork!.turnIndex} → ${newSessionId}`); + + // Ensure the source session is loaded so on-disk data is available + if (!this._sessions.has(sourceSessionId)) { + await this._resumeSession(sourceSessionId); + } + + const copilotDataDir = getCopilotDataDir(); + await forkCopilotSessionOnDisk(copilotDataDir, sourceSessionId, newSessionId, config.fork!.turnIndex); + + // Resume the forked session so the SDK loads the forked history + const agentSession = await this._resumeSession(newSessionId); + const session = agentSession.sessionUri; + this._logService.info(`[Copilot] Forked session created: ${session.toString()}`); + return session; + }); + } + + const sessionId = config?.session ? AgentSession.id(config.session) : generateUuid(); const factory: SessionWrapperFactory = async callbacks => { const raw = await client.createSession({ model: config?.model, - sessionId: config?.session ? AgentSession.id(config.session) : undefined, + sessionId, streaming: true, workingDirectory: config?.workingDirectory?.fsPath, onPermissionRequest: callbacks.onPermissionRequest, @@ -193,7 +224,7 @@ export class CopilotAgent extends Disposable implements IAgent { return new CopilotSessionWrapper(raw); }; - const agentSession = this._createAgentSession(factory, config?.workingDirectory, config?.session ? AgentSession.id(config.session) : undefined); + const agentSession = this._createAgentSession(factory, config?.workingDirectory, sessionId); await agentSession.initializeSession(); const session = agentSession.sessionUri; @@ -203,8 +234,10 @@ export class CopilotAgent extends Disposable implements IAgent { async sendMessage(session: URI, prompt: string, attachments?: IAgentAttachment[]): Promise { const sessionId = AgentSession.id(session); - const entry = this._sessions.get(sessionId) ?? await this._resumeSession(sessionId); - await entry.send(prompt, attachments); + await this._sessionSequencer.queue(sessionId, async () => { + const entry = this._sessions.get(sessionId) ?? await this._resumeSession(sessionId); + await entry.send(prompt, attachments); + }); } setPendingMessages(session: URI, steeringMessage: IPendingMessage | undefined, _queuedMessages: readonly IPendingMessage[]): void { @@ -236,15 +269,57 @@ export class CopilotAgent extends Disposable implements IAgent { async disposeSession(session: URI): Promise { const sessionId = AgentSession.id(session); - this._sessions.deleteAndDispose(sessionId); + await this._sessionSequencer.queue(sessionId, async () => { + this._sessions.deleteAndDispose(sessionId); + }); } async abortSession(session: URI): Promise { const sessionId = AgentSession.id(session); - const entry = this._sessions.get(sessionId); - if (entry) { - await entry.abort(); - } + await this._sessionSequencer.queue(sessionId, async () => { + const entry = this._sessions.get(sessionId); + if (entry) { + await entry.abort(); + } + }); + } + + async truncateSession(session: URI, turnIndex?: number): Promise { + const sessionId = AgentSession.id(session); + await this._sessionSequencer.queue(sessionId, async () => { + this._logService.info(`[Copilot:${sessionId}] Truncating session${turnIndex !== undefined ? ` at index ${turnIndex}` : ' (all turns)'}`); + + const keepUpToTurnIndex = turnIndex ?? -1; + + // Destroy the SDK session first and wait for cleanup to complete, + // ensuring on-disk data (events.jsonl, locks) is released before + // we modify it. Then dispose the wrapper. + const entry = this._sessions.get(sessionId); + if (entry) { + await entry.destroySession(); + } + this._sessions.deleteAndDispose(sessionId); + + if (keepUpToTurnIndex >= 0) { + const copilotDataDir = getCopilotDataDir(); + await truncateCopilotSessionOnDisk(copilotDataDir, sessionId, keepUpToTurnIndex); + } + + // Resume the session from the modified on-disk data + await this._resumeSession(sessionId); + this._logService.info(`[Copilot:${sessionId}] Session truncated and resumed`); + }); + } + + async forkSession(sourceSession: URI, newSessionId: string, turnIndex: number): Promise { + const sourceSessionId = AgentSession.id(sourceSession); + await this._sessionSequencer.queue(sourceSessionId, async () => { + this._logService.info(`[Copilot] Forking session ${sourceSessionId} at index ${turnIndex} → ${newSessionId}`); + + const copilotDataDir = getCopilotDataDir(); + await forkCopilotSessionOnDisk(copilotDataDir, sourceSessionId, newSessionId, turnIndex); + this._logService.info(`[Copilot] Forked session ${newSessionId} created on disk`); + }); } async changeModel(session: URI, model: string): Promise { @@ -284,20 +359,19 @@ export class CopilotAgent extends Disposable implements IAgent { * and returns it. The caller must call {@link CopilotAgentSession.initializeSession} * to wire up the SDK session. */ - private _createAgentSession(wrapperFactory: SessionWrapperFactory, workingDirectory: URI | undefined, sessionIdOverride?: string): CopilotAgentSession { - const rawId = sessionIdOverride ?? generateUuid(); - const sessionUri = AgentSession.uri(this.id, rawId); + private _createAgentSession(wrapperFactory: SessionWrapperFactory, workingDirectory: URI | undefined, sessionId: string): CopilotAgentSession { + const sessionUri = AgentSession.uri(this.id, sessionId); const agentSession = this._instantiationService.createInstance( CopilotAgentSession, sessionUri, - rawId, + sessionId, workingDirectory, this._onDidSessionProgress, wrapperFactory, ); - this._sessions.set(rawId, agentSession); + this._sessions.set(sessionId, agentSession); return agentSession; } diff --git a/src/vs/platform/agentHost/node/copilot/copilotAgentForking.ts b/src/vs/platform/agentHost/node/copilot/copilotAgentForking.ts new file mode 100644 index 00000000000..b8e662adc0c --- /dev/null +++ b/src/vs/platform/agentHost/node/copilot/copilotAgentForking.ts @@ -0,0 +1,565 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as fs from 'fs'; +import * as os from 'os'; +import type { Database } from '@vscode/sqlite3'; +import { generateUuid } from '../../../../base/common/uuid.js'; +import * as path from '../../../../base/common/path.js'; + +// ---- Types ------------------------------------------------------------------ + +/** + * A single event entry from a Copilot CLI `events.jsonl` file. + * The Copilot CLI stores session history as a newline-delimited JSON log + * where events form a linked list via `parentId`. + */ +export interface ICopilotEventLogEntry { + readonly type: string; + readonly data: Record; + readonly id: string; + readonly timestamp: string; + readonly parentId: string | null; +} + +// ---- Promise wrappers around callback-based @vscode/sqlite3 API ----------- + +function dbExec(db: Database, sql: string): Promise { + return new Promise((resolve, reject) => { + db.exec(sql, err => err ? reject(err) : resolve()); + }); +} + +function dbRun(db: Database, sql: string, params: unknown[]): Promise { + return new Promise((resolve, reject) => { + db.run(sql, params, function (err: Error | null) { + if (err) { + return reject(err); + } + resolve(); + }); + }); +} + +function dbAll(db: Database, sql: string, params: unknown[]): Promise[]> { + return new Promise((resolve, reject) => { + db.all(sql, params, (err: Error | null, rows: Record[]) => { + if (err) { + return reject(err); + } + resolve(rows); + }); + }); +} + +function dbClose(db: Database): Promise { + return new Promise((resolve, reject) => { + db.close(err => err ? reject(err) : resolve()); + }); +} + +function dbOpen(dbPath: string): Promise { + return new Promise((resolve, reject) => { + import('@vscode/sqlite3').then(sqlite3 => { + const db = new sqlite3.default.Database(dbPath, (err: Error | null) => { + if (err) { + return reject(err); + } + resolve(db); + }); + }, reject); + }); +} + +// ---- Pure functions (testable, no I/O) ------------------------------------ + +/** + * Parses a JSONL string into an array of event log entries. + */ +export function parseEventLog(content: string): ICopilotEventLogEntry[] { + const entries: ICopilotEventLogEntry[] = []; + for (const line of content.split('\n')) { + const trimmed = line.trim(); + if (trimmed.length === 0) { + continue; + } + entries.push(JSON.parse(trimmed)); + } + return entries; +} + +/** + * Serializes an array of event log entries back into a JSONL string. + */ +export function serializeEventLog(entries: readonly ICopilotEventLogEntry[]): string { + return entries.map(e => JSON.stringify(e)).join('\n') + '\n'; +} + +/** + * Finds the index of the last event that belongs to the given turn (0-based). + * + * A "turn" corresponds to one `user.message` event and all subsequent events + * up to (and including) the `assistant.turn_end` that closes that interaction, + * or the `session.shutdown` that ends the session. + * + * @returns The inclusive index of the last event in the specified turn, + * or `-1` if the turn is not found. + */ +export function findTurnBoundaryInEventLog(entries: readonly ICopilotEventLogEntry[], turnIndex: number): number { + let userMessageCount = -1; + let lastEventForTurn = -1; + + for (let i = 0; i < entries.length; i++) { + const entry = entries[i]; + + if (entry.type === 'user.message') { + userMessageCount++; + if (userMessageCount > turnIndex) { + // We've entered the next turn — stop + return lastEventForTurn; + } + } + + if (userMessageCount === turnIndex) { + lastEventForTurn = i; + } + } + + // If we scanned everything and the target turn was found, return its last event + return lastEventForTurn; +} + +/** + * Builds a forked event log from the source session's events. + * + * - Keeps events up to and including the specified fork turn (0-based). + * - Rewrites `session.start` with the new session ID. + * - Generates fresh UUIDs for all events. + * - Re-chains `parentId` links via an old→new ID map. + * - Strips `session.shutdown` and `session.resume` lifecycle events. + */ +export function buildForkedEventLog( + entries: readonly ICopilotEventLogEntry[], + forkTurnIndex: number, + newSessionId: string, +): ICopilotEventLogEntry[] { + const boundary = findTurnBoundaryInEventLog(entries, forkTurnIndex); + if (boundary < 0) { + throw new Error(`Fork turn index ${forkTurnIndex} not found in event log`); + } + + // Keep events up to boundary, filtering out lifecycle events + const kept = entries + .slice(0, boundary + 1) + .filter(e => e.type !== 'session.shutdown' && e.type !== 'session.resume'); + + // Build UUID remap and re-chain + const idMap = new Map(); + const result: ICopilotEventLogEntry[] = []; + + for (const entry of kept) { + const newId = generateUuid(); + idMap.set(entry.id, newId); + + let data = entry.data; + if (entry.type === 'session.start') { + data = { ...data, sessionId: newSessionId }; + } + + const newParentId = entry.parentId !== null + ? (idMap.get(entry.parentId) ?? null) + : null; + + result.push({ + type: entry.type, + data, + id: newId, + timestamp: entry.timestamp, + parentId: newParentId, + }); + } + + return result; +} + +/** + * Builds a truncated event log from the source session's events. + * + * - Keeps events up to and including the specified turn (0-based). + * - Prepends a new `session.start` event using the original start data. + * - Re-chains `parentId` links for remaining events. + */ +export function buildTruncatedEventLog( + entries: readonly ICopilotEventLogEntry[], + keepUpToTurnIndex: number, +): ICopilotEventLogEntry[] { + const boundary = findTurnBoundaryInEventLog(entries, keepUpToTurnIndex); + if (boundary < 0) { + throw new Error(`Turn index ${keepUpToTurnIndex} not found in event log`); + } + + // Find the original session.start for its metadata + const originalStart = entries.find(e => e.type === 'session.start'); + if (!originalStart) { + throw new Error('No session.start event found in event log'); + } + + // Keep events from after session start up to boundary, stripping lifecycle events + const kept = entries + .slice(0, boundary + 1) + .filter(e => e.type !== 'session.start' && e.type !== 'session.shutdown' && e.type !== 'session.resume'); + + // Build new start event + const newStartId = generateUuid(); + const newStart: ICopilotEventLogEntry = { + type: 'session.start', + data: { ...originalStart.data, startTime: new Date().toISOString() }, + id: newStartId, + timestamp: new Date().toISOString(), + parentId: null, + }; + + // Re-chain: first remaining event points to the new start + const idMap = new Map(); + idMap.set(originalStart.id, newStartId); + + const result: ICopilotEventLogEntry[] = [newStart]; + let lastId = newStartId; + + for (const entry of kept) { + const newId = generateUuid(); + idMap.set(entry.id, newId); + + const newParentId = entry.parentId !== null + ? (idMap.get(entry.parentId) ?? lastId) + : lastId; + + result.push({ + type: entry.type, + data: entry.data, + id: newId, + timestamp: entry.timestamp, + parentId: newParentId, + }); + lastId = newId; + } + + return result; +} + +/** + * Generates a `workspace.yaml` file content for a Copilot CLI session. + */ +export function buildWorkspaceYaml(sessionId: string, cwd: string, summary: string): string { + const now = new Date().toISOString(); + return [ + `id: ${sessionId}`, + `cwd: ${cwd}`, + `summary_count: 0`, + `created_at: ${now}`, + `updated_at: ${now}`, + `summary: ${summary}`, + '', + ].join('\n'); +} + +// ---- SQLite operations (Copilot CLI session-store.db) --------------------- + +/** + * Forks a session record in the Copilot CLI's `session-store.db`. + * + * Copies the source session's metadata, turns (up to `forkTurnIndex`), + * session files, search index entries, and checkpoints into a new session. + */ +export async function forkSessionInDb( + db: Database, + sourceSessionId: string, + newSessionId: string, + forkTurnIndex: number, +): Promise { + await dbExec(db, 'PRAGMA foreign_keys = ON'); + await dbExec(db, 'BEGIN TRANSACTION'); + try { + const now = new Date().toISOString(); + + // Copy session row + await dbRun(db, + `INSERT INTO sessions (id, cwd, repository, branch, summary, created_at, updated_at, host_type) + SELECT ?, cwd, repository, branch, summary, ?, ?, host_type + FROM sessions WHERE id = ?`, + [newSessionId, now, now, sourceSessionId], + ); + + // Copy turns up to fork point (turn_index is 0-based) + await dbRun(db, + `INSERT INTO turns (session_id, turn_index, user_message, assistant_response, timestamp) + SELECT ?, turn_index, user_message, assistant_response, timestamp + FROM turns + WHERE session_id = ? AND turn_index <= ?`, + [newSessionId, sourceSessionId, forkTurnIndex], + ); + + // Copy session files that were first seen at or before the fork point + await dbRun(db, + `INSERT INTO session_files (session_id, file_path, tool_name, turn_index, first_seen_at) + SELECT ?, file_path, tool_name, turn_index, first_seen_at + FROM session_files + WHERE session_id = ? AND turn_index <= ?`, + [newSessionId, sourceSessionId, forkTurnIndex], + ); + + // Copy search index entries for kept turns + await dbRun(db, + `INSERT INTO search_index (content, session_id, source_type, source_id) + SELECT content, ?, source_type, + REPLACE(source_id, ?, ?) + FROM search_index + WHERE session_id = ? + AND source_type = 'turn'`, + [newSessionId, sourceSessionId, newSessionId, sourceSessionId], + ); + + // Copy checkpoints at or before the fork point + await dbRun(db, + `INSERT INTO checkpoints (session_id, checkpoint_number, title, overview, history, work_done, technical_details, important_files, next_steps, created_at) + SELECT ?, checkpoint_number, title, overview, history, work_done, technical_details, important_files, next_steps, created_at + FROM checkpoints + WHERE session_id = ?`, + [newSessionId, sourceSessionId], + ); + + await dbExec(db, 'COMMIT'); + } catch (err) { + await dbExec(db, 'ROLLBACK'); + throw err; + } +} + +/** + * Truncates a session in the Copilot CLI's `session-store.db`. + * + * Removes all turns after `keepUpToTurnIndex` and updates session metadata. + */ +export async function truncateSessionInDb( + db: Database, + sessionId: string, + keepUpToTurnIndex: number, +): Promise { + await dbExec(db, 'PRAGMA foreign_keys = ON'); + await dbExec(db, 'BEGIN TRANSACTION'); + try { + const now = new Date().toISOString(); + + // Delete turns after the truncation point + await dbRun(db, + `DELETE FROM turns WHERE session_id = ? AND turn_index > ?`, + [sessionId, keepUpToTurnIndex], + ); + + // Update session timestamp + await dbRun(db, + `UPDATE sessions SET updated_at = ? WHERE id = ?`, + [now, sessionId], + ); + + // Remove search index entries for removed turns + // source_id format is ":turn:" + await dbAll(db, + `SELECT source_id FROM search_index + WHERE session_id = ? AND source_type = 'turn'`, + [sessionId], + ).then(async rows => { + const prefix = `${sessionId}:turn:`; + for (const row of rows) { + const sourceId = row.source_id as string; + if (sourceId.startsWith(prefix)) { + const turnIdx = parseInt(sourceId.substring(prefix.length), 10); + if (!isNaN(turnIdx) && turnIdx > keepUpToTurnIndex) { + await dbRun(db, + `DELETE FROM search_index WHERE source_id = ? AND session_id = ?`, + [sourceId, sessionId], + ); + } + } + } + }); + + await dbExec(db, 'COMMIT'); + } catch (err) { + await dbExec(db, 'ROLLBACK'); + throw err; + } +} + +// ---- File system operations ----------------------------------------------- + +/** + * Resolves the Copilot CLI data directory. + * The Copilot CLI stores its data in `~/.copilot/` by default, or in the + * directory specified by `COPILOT_CONFIG_DIR`. + */ +export function getCopilotDataDir(): string { + return process.env['COPILOT_CONFIG_DIR'] ?? path.join(os.homedir(), '.copilot'); +} + +/** + * Forks a Copilot CLI session on disk. + * + * 1. Reads the source session's `events.jsonl` + * 2. Builds a forked event log + * 3. Creates the new session folder with all required files/directories + * 4. Updates the `session-store.db` + * + * @param copilotDataDir Path to the `.copilot` directory + * @param sourceSessionId UUID of the source session to fork from + * @param newSessionId UUID for the new forked session + * @param forkTurnIndex 0-based turn index to fork at (inclusive) + */ +export async function forkCopilotSessionOnDisk( + copilotDataDir: string, + sourceSessionId: string, + newSessionId: string, + forkTurnIndex: number, +): Promise { + const sessionStateDir = path.join(copilotDataDir, 'session-state'); + + // Read source events + const sourceEventsPath = path.join(sessionStateDir, sourceSessionId, 'events.jsonl'); + const sourceContent = await fs.promises.readFile(sourceEventsPath, 'utf-8'); + const sourceEntries = parseEventLog(sourceContent); + + // Build forked event log + const forkedEntries = buildForkedEventLog(sourceEntries, forkTurnIndex, newSessionId); + + // Read source workspace.yaml for cwd/summary + let cwd = ''; + let summary = ''; + try { + const workspaceYamlPath = path.join(sessionStateDir, sourceSessionId, 'workspace.yaml'); + const yamlContent = await fs.promises.readFile(workspaceYamlPath, 'utf-8'); + const cwdMatch = yamlContent.match(/^cwd:\s*(.+)$/m); + const summaryMatch = yamlContent.match(/^summary:\s*(.+)$/m); + if (cwdMatch) { + cwd = cwdMatch[1].trim(); + } + if (summaryMatch) { + summary = summaryMatch[1].trim(); + } + } catch { + // Fall back to session.start data + const startEvent = sourceEntries.find(e => e.type === 'session.start'); + if (startEvent) { + const ctx = startEvent.data.context as Record | undefined; + cwd = ctx?.cwd ?? ''; + } + } + + // Create new session folder structure + const newSessionDir = path.join(sessionStateDir, newSessionId); + await fs.promises.mkdir(path.join(newSessionDir, 'checkpoints'), { recursive: true }); + await fs.promises.mkdir(path.join(newSessionDir, 'files'), { recursive: true }); + await fs.promises.mkdir(path.join(newSessionDir, 'research'), { recursive: true }); + + // Write events.jsonl + await fs.promises.writeFile( + path.join(newSessionDir, 'events.jsonl'), + serializeEventLog(forkedEntries), + 'utf-8', + ); + + // Write workspace.yaml + await fs.promises.writeFile( + path.join(newSessionDir, 'workspace.yaml'), + buildWorkspaceYaml(newSessionId, cwd, summary), + 'utf-8', + ); + + // Write empty vscode.metadata.json + await fs.promises.writeFile( + path.join(newSessionDir, 'vscode.metadata.json'), + '{}', + 'utf-8', + ); + + // Write empty checkpoints index + await fs.promises.writeFile( + path.join(newSessionDir, 'checkpoints', 'index.md'), + '', + 'utf-8', + ); + + // Update session-store.db + const dbPath = path.join(copilotDataDir, 'session-store.db'); + const db = await dbOpen(dbPath); + try { + await forkSessionInDb(db, sourceSessionId, newSessionId, forkTurnIndex); + } finally { + await dbClose(db); + } +} + +/** + * Truncates a Copilot CLI session on disk. + * + * 1. Reads the session's `events.jsonl` + * 2. Builds a truncated event log + * 3. Overwrites `events.jsonl` and updates `workspace.yaml` + * 4. Updates the `session-store.db` + * + * @param copilotDataDir Path to the `.copilot` directory + * @param sessionId UUID of the session to truncate + * @param keepUpToTurnIndex 0-based turn index to keep up to (inclusive) + */ +export async function truncateCopilotSessionOnDisk( + copilotDataDir: string, + sessionId: string, + keepUpToTurnIndex: number, +): Promise { + const sessionStateDir = path.join(copilotDataDir, 'session-state'); + const sessionDir = path.join(sessionStateDir, sessionId); + + // Read and truncate events + const eventsPath = path.join(sessionDir, 'events.jsonl'); + const content = await fs.promises.readFile(eventsPath, 'utf-8'); + const entries = parseEventLog(content); + const truncatedEntries = buildTruncatedEventLog(entries, keepUpToTurnIndex); + + // Overwrite events.jsonl + await fs.promises.writeFile(eventsPath, serializeEventLog(truncatedEntries), 'utf-8'); + + // Update workspace.yaml timestamp + try { + const yamlPath = path.join(sessionDir, 'workspace.yaml'); + let yaml = await fs.promises.readFile(yamlPath, 'utf-8'); + yaml = yaml.replace(/^updated_at:\s*.+$/m, `updated_at: ${new Date().toISOString()}`); + await fs.promises.writeFile(yamlPath, yaml, 'utf-8'); + } catch { + // workspace.yaml may not exist (old format) + } + + // Update session-store.db + const dbPath = path.join(copilotDataDir, 'session-store.db'); + const db = await dbOpen(dbPath); + try { + await truncateSessionInDb(db, sessionId, keepUpToTurnIndex); + } finally { + await dbClose(db); + } +} + +/** + * Maps a protocol turn ID to a 0-based turn index by finding the turn's + * position within the session's event log. + * + * The protocol state assigns arbitrary string IDs to turns, but the Copilot + * CLI's `events.jsonl` uses sequential `user.message` events. To bridge the + * two, we match turns by their position in the sequence. + * + * @returns The 0-based turn index, or `-1` if the turn ID is not found in the + * `turnIds` array. + */ +export function turnIdToIndex(turnIds: readonly string[], turnId: string): number { + return turnIds.indexOf(turnId); +} diff --git a/src/vs/platform/agentHost/node/copilot/copilotAgentSession.ts b/src/vs/platform/agentHost/node/copilot/copilotAgentSession.ts index 9b79f53d587..a2a095a7ab2 100644 --- a/src/vs/platform/agentHost/node/copilot/copilotAgentSession.ts +++ b/src/vs/platform/agentHost/node/copilot/copilotAgentSession.ts @@ -228,6 +228,16 @@ export class CopilotAgentSession extends Disposable { await this._wrapper.session.abort(); } + /** + * Explicitly destroys the underlying SDK session and waits for cleanup + * to complete. Call this before {@link dispose} when you need to ensure + * the session's on-disk data is no longer locked (e.g. before + * truncation or fork operations that modify the session files). + */ + async destroySession(): Promise { + await this._wrapper.session.destroy(); + } + async setModel(model: string): Promise { this._logService.info(`[Copilot:${this.sessionId}] Changing model to: ${model}`); await this._wrapper.session.setModel(model); diff --git a/src/vs/platform/agentHost/node/protocolServerHandler.ts b/src/vs/platform/agentHost/node/protocolServerHandler.ts index 640f5263103..43166b173bc 100644 --- a/src/vs/platform/agentHost/node/protocolServerHandler.ts +++ b/src/vs/platform/agentHost/node/protocolServerHandler.ts @@ -310,12 +310,24 @@ export class ProtocolServerHandler extends Disposable { }, createSession: async (_client, params) => { let createdSession: URI; + // Resolve fork turnId to a 0-based index using the source session's + // turn list in the state manager. + let fork: { session: URI; turnIndex: number } | undefined; + if (params.fork) { + const sourceState = this._stateManager.getSessionState(params.fork.session); + const turnIndex = sourceState?.turns.findIndex(t => t.id === params.fork!.turnId) ?? -1; + if (turnIndex < 0) { + throw new ProtocolError(AHP_PROVIDER_NOT_FOUND, `Fork turn ID ${params.fork.turnId} not found in session ${params.fork.session}`); + } + fork = { session: URI.parse(params.fork.session), turnIndex }; + } try { createdSession = await this._agentService.createSession({ provider: params.provider, model: params.model, workingDirectory: params.workingDirectory ? URI.parse(params.workingDirectory) : undefined, session: URI.parse(params.session), + fork, }); } catch (err) { if (err instanceof ProtocolError) { diff --git a/src/vs/platform/agentHost/test/node/copilotAgentForking.test.ts b/src/vs/platform/agentHost/test/node/copilotAgentForking.test.ts new file mode 100644 index 00000000000..2ea08be14ee --- /dev/null +++ b/src/vs/platform/agentHost/test/node/copilotAgentForking.test.ts @@ -0,0 +1,650 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import assert from 'assert'; +import { ensureNoDisposablesAreLeakedInTestSuite } from '../../../../base/test/common/utils.js'; +import { + parseEventLog, + serializeEventLog, + findTurnBoundaryInEventLog, + buildForkedEventLog, + buildTruncatedEventLog, + buildWorkspaceYaml, + forkSessionInDb, + truncateSessionInDb, + type ICopilotEventLogEntry, +} from '../../node/copilot/copilotAgentForking.js'; + +// ---- Test helpers ----------------------------------------------------------- + +function makeEntry(type: string, overrides?: Partial): ICopilotEventLogEntry { + return { + type, + data: {}, + id: `id-${Math.random().toString(36).slice(2, 8)}`, + timestamp: new Date().toISOString(), + parentId: null, + ...overrides, + }; +} + +/** + * Builds a minimal event log representing a multi-turn session. + * Each turn = user.message → assistant.turn_start → assistant.message → assistant.turn_end. + */ +function buildTestEventLog(turnCount: number): ICopilotEventLogEntry[] { + const entries: ICopilotEventLogEntry[] = []; + let lastId: string | null = null; + + const sessionStart = makeEntry('session.start', { + id: 'session-start-id', + data: { sessionId: 'source-session', context: { cwd: '/test' } }, + parentId: null, + }); + entries.push(sessionStart); + lastId = sessionStart.id; + + for (let turn = 0; turn < turnCount; turn++) { + const userMsg = makeEntry('user.message', { + id: `user-msg-${turn}`, + data: { content: `Turn ${turn} message` }, + parentId: lastId, + }); + entries.push(userMsg); + lastId = userMsg.id; + + const turnStart = makeEntry('assistant.turn_start', { + id: `turn-start-${turn}`, + data: { turnId: String(turn) }, + parentId: lastId, + }); + entries.push(turnStart); + lastId = turnStart.id; + + const assistantMsg = makeEntry('assistant.message', { + id: `assistant-msg-${turn}`, + data: { content: `Response ${turn}` }, + parentId: lastId, + }); + entries.push(assistantMsg); + lastId = assistantMsg.id; + + const turnEnd = makeEntry('assistant.turn_end', { + id: `turn-end-${turn}`, + parentId: lastId, + }); + entries.push(turnEnd); + lastId = turnEnd.id; + } + + return entries; +} + +suite('CopilotAgentForking', () => { + + ensureNoDisposablesAreLeakedInTestSuite(); + + // ---- parseEventLog / serializeEventLog ------------------------------ + + suite('parseEventLog', () => { + + test('parses a single-line JSONL', () => { + const entry = makeEntry('session.start'); + const jsonl = JSON.stringify(entry); + const result = parseEventLog(jsonl); + assert.strictEqual(result.length, 1); + assert.strictEqual(result[0].type, 'session.start'); + }); + + test('parses multi-line JSONL', () => { + const entries = [makeEntry('session.start'), makeEntry('user.message')]; + const jsonl = entries.map(e => JSON.stringify(e)).join('\n'); + const result = parseEventLog(jsonl); + assert.strictEqual(result.length, 2); + }); + + test('ignores empty lines', () => { + const entry = makeEntry('session.start'); + const jsonl = '\n' + JSON.stringify(entry) + '\n\n'; + const result = parseEventLog(jsonl); + assert.strictEqual(result.length, 1); + }); + + test('empty input returns empty array', () => { + assert.deepStrictEqual(parseEventLog(''), []); + assert.deepStrictEqual(parseEventLog('\n\n'), []); + }); + }); + + suite('serializeEventLog', () => { + + test('round-trips correctly', () => { + const entries = buildTestEventLog(2); + const serialized = serializeEventLog(entries); + const parsed = parseEventLog(serialized); + assert.strictEqual(parsed.length, entries.length); + for (let i = 0; i < entries.length; i++) { + assert.strictEqual(parsed[i].id, entries[i].id); + assert.strictEqual(parsed[i].type, entries[i].type); + } + }); + + test('ends with a newline', () => { + const entries = [makeEntry('session.start')]; + const serialized = serializeEventLog(entries); + assert.ok(serialized.endsWith('\n')); + }); + }); + + // ---- findTurnBoundaryInEventLog ------------------------------------- + + suite('findTurnBoundaryInEventLog', () => { + + test('finds first turn boundary', () => { + const entries = buildTestEventLog(3); + const boundary = findTurnBoundaryInEventLog(entries, 0); + // Turn 0: user.message(1) + turn_start(2) + assistant.message(3) + turn_end(4) + // Index 4 = turn_end of turn 0 + assert.strictEqual(boundary, 4); + assert.strictEqual(entries[boundary].type, 'assistant.turn_end'); + assert.strictEqual(entries[boundary].id, 'turn-end-0'); + }); + + test('finds middle turn boundary', () => { + const entries = buildTestEventLog(3); + const boundary = findTurnBoundaryInEventLog(entries, 1); + // Turn 1 ends at index 8 + assert.strictEqual(boundary, 8); + assert.strictEqual(entries[boundary].type, 'assistant.turn_end'); + assert.strictEqual(entries[boundary].id, 'turn-end-1'); + }); + + test('finds last turn boundary', () => { + const entries = buildTestEventLog(3); + const boundary = findTurnBoundaryInEventLog(entries, 2); + assert.strictEqual(boundary, entries.length - 1); + assert.strictEqual(entries[boundary].type, 'assistant.turn_end'); + assert.strictEqual(entries[boundary].id, 'turn-end-2'); + }); + + test('returns -1 for non-existent turn', () => { + const entries = buildTestEventLog(2); + assert.strictEqual(findTurnBoundaryInEventLog(entries, 5), -1); + }); + + test('returns -1 for empty log', () => { + assert.strictEqual(findTurnBoundaryInEventLog([], 0), -1); + }); + }); + + // ---- buildForkedEventLog -------------------------------------------- + + suite('buildForkedEventLog', () => { + + test('forks at turn 0', () => { + const entries = buildTestEventLog(3); + const forked = buildForkedEventLog(entries, 0, 'new-session-id'); + + // Should have session.start + turn 0 events (user.message, turn_start, assistant.message, turn_end) + assert.strictEqual(forked.length, 5); + assert.strictEqual(forked[0].type, 'session.start'); + assert.strictEqual((forked[0].data as Record).sessionId, 'new-session-id'); + }); + + test('forks at turn 1', () => { + const entries = buildTestEventLog(3); + const forked = buildForkedEventLog(entries, 1, 'new-session-id'); + + // session.start + 2 turns × 4 events = 9 events + assert.strictEqual(forked.length, 9); + }); + + test('generates unique UUIDs', () => { + const entries = buildTestEventLog(2); + const forked = buildForkedEventLog(entries, 0, 'new-session-id'); + + const ids = new Set(forked.map(e => e.id)); + assert.strictEqual(ids.size, forked.length, 'All IDs should be unique'); + + // None should match the original + for (const entry of forked) { + assert.ok(!entries.some(e => e.id === entry.id), 'Should not reuse original IDs'); + } + }); + + test('re-chains parentId links', () => { + const entries = buildTestEventLog(2); + const forked = buildForkedEventLog(entries, 0, 'new-session-id'); + + // First event has no parent + assert.strictEqual(forked[0].parentId, null); + + // Each subsequent event's parentId should be a valid ID in the forked log + const idSet = new Set(forked.map(e => e.id)); + for (let i = 1; i < forked.length; i++) { + assert.ok( + forked[i].parentId !== null && idSet.has(forked[i].parentId!), + `Event ${i} (${forked[i].type}) should have a valid parentId`, + ); + } + }); + + test('strips session.shutdown and session.resume events', () => { + const entries = buildTestEventLog(2); + // Insert lifecycle events + entries.splice(5, 0, makeEntry('session.shutdown', { parentId: entries[4].id })); + entries.splice(6, 0, makeEntry('session.resume', { parentId: entries[5].id })); + + const forked = buildForkedEventLog(entries, 1, 'new-session-id'); + assert.ok(!forked.some(e => e.type === 'session.shutdown')); + assert.ok(!forked.some(e => e.type === 'session.resume')); + }); + + test('throws for invalid turn index', () => { + const entries = buildTestEventLog(1); + assert.throws(() => buildForkedEventLog(entries, 5, 'new-session-id')); + }); + }); + + // ---- buildTruncatedEventLog ----------------------------------------- + + suite('buildTruncatedEventLog', () => { + + test('truncates to turn 0', () => { + const entries = buildTestEventLog(3); + const truncated = buildTruncatedEventLog(entries, 0); + + // New session.start + turn 0 events (user.message, turn_start, assistant.message, turn_end) + assert.strictEqual(truncated.length, 5); + assert.strictEqual(truncated[0].type, 'session.start'); + }); + + test('truncates to turn 1', () => { + const entries = buildTestEventLog(3); + const truncated = buildTruncatedEventLog(entries, 1); + + // New session.start + 2 turns × 4 events = 9 events + assert.strictEqual(truncated.length, 9); + }); + + test('prepends fresh session.start', () => { + const entries = buildTestEventLog(2); + const truncated = buildTruncatedEventLog(entries, 0); + + assert.strictEqual(truncated[0].type, 'session.start'); + assert.strictEqual(truncated[0].parentId, null); + // Should not reuse original session.start ID + assert.notStrictEqual(truncated[0].id, entries[0].id); + }); + + test('re-chains parentId links', () => { + const entries = buildTestEventLog(2); + const truncated = buildTruncatedEventLog(entries, 0); + + const idSet = new Set(truncated.map(e => e.id)); + for (let i = 1; i < truncated.length; i++) { + assert.ok( + truncated[i].parentId !== null && idSet.has(truncated[i].parentId!), + `Event ${i} (${truncated[i].type}) should have a valid parentId`, + ); + } + }); + + test('strips lifecycle events', () => { + const entries = buildTestEventLog(3); + // Add lifecycle events between turns + entries.splice(5, 0, makeEntry('session.shutdown')); + entries.splice(6, 0, makeEntry('session.resume')); + + const truncated = buildTruncatedEventLog(entries, 2); + const lifecycleEvents = truncated.filter( + e => e.type === 'session.shutdown' || e.type === 'session.resume', + ); + assert.strictEqual(lifecycleEvents.length, 0); + }); + + test('throws for invalid turn index', () => { + const entries = buildTestEventLog(1); + assert.throws(() => buildTruncatedEventLog(entries, 5)); + }); + + test('throws when no session.start exists', () => { + const entries = [makeEntry('user.message')]; + assert.throws(() => buildTruncatedEventLog(entries, 0)); + }); + }); + + // ---- buildWorkspaceYaml --------------------------------------------- + + suite('buildWorkspaceYaml', () => { + + test('contains required fields', () => { + const yaml = buildWorkspaceYaml('test-id', '/home/user/project', 'Test summary'); + assert.ok(yaml.includes('id: test-id')); + assert.ok(yaml.includes('cwd: /home/user/project')); + assert.ok(yaml.includes('summary: Test summary')); + assert.ok(yaml.includes('summary_count: 0')); + assert.ok(yaml.includes('created_at:')); + assert.ok(yaml.includes('updated_at:')); + }); + }); + + // ---- SQLite operations (in-memory) ---------------------------------- + + suite('forkSessionInDb', () => { + + async function openTestDb(): Promise { + const sqlite3 = await import('@vscode/sqlite3'); + return new Promise((resolve, reject) => { + const db = new sqlite3.default.Database(':memory:', (err: Error | null) => { + if (err) { + return reject(err); + } + resolve(db); + }); + }); + } + + function exec(db: import('@vscode/sqlite3').Database, sql: string): Promise { + return new Promise((resolve, reject) => { + db.exec(sql, err => err ? reject(err) : resolve()); + }); + } + + function all(db: import('@vscode/sqlite3').Database, sql: string, params: unknown[] = []): Promise[]> { + return new Promise((resolve, reject) => { + db.all(sql, params, (err: Error | null, rows: Record[]) => { + if (err) { + return reject(err); + } + resolve(rows); + }); + }); + } + + function close(db: import('@vscode/sqlite3').Database): Promise { + return new Promise((resolve, reject) => { + db.close(err => err ? reject(err) : resolve()); + }); + } + + async function setupSchema(db: import('@vscode/sqlite3').Database): Promise { + await exec(db, ` + CREATE TABLE sessions ( + id TEXT PRIMARY KEY, + cwd TEXT, + repository TEXT, + branch TEXT, + summary TEXT, + created_at TEXT, + updated_at TEXT, + host_type TEXT + ); + CREATE TABLE turns ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + turn_index INTEGER NOT NULL, + user_message TEXT, + assistant_response TEXT, + timestamp TEXT, + UNIQUE(session_id, turn_index) + ); + CREATE TABLE session_files ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + file_path TEXT, + tool_name TEXT, + turn_index INTEGER, + first_seen_at TEXT + ); + CREATE VIRTUAL TABLE search_index USING fts5( + content, + session_id, + source_type, + source_id + ); + CREATE TABLE checkpoints ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + checkpoint_number INTEGER, + title TEXT, + overview TEXT, + history TEXT, + work_done TEXT, + technical_details TEXT, + important_files TEXT, + next_steps TEXT, + created_at TEXT + ); + `); + } + + async function seedTestData(db: import('@vscode/sqlite3').Database, sessionId: string, turnCount: number): Promise { + await exec(db, ` + INSERT INTO sessions (id, cwd, repository, branch, summary, created_at, updated_at, host_type) + VALUES ('${sessionId}', '/test', 'test-repo', 'main', 'Test session', '2026-01-01', '2026-01-01', 'github'); + `); + for (let i = 0; i < turnCount; i++) { + await exec(db, ` + INSERT INTO turns (session_id, turn_index, user_message, assistant_response, timestamp) + VALUES ('${sessionId}', ${i}, 'msg ${i}', 'resp ${i}', '2026-01-01'); + `); + await exec(db, ` + INSERT INTO session_files (session_id, file_path, tool_name, turn_index, first_seen_at) + VALUES ('${sessionId}', 'file${i}.ts', 'edit', ${i}, '2026-01-01'); + `); + } + } + + test('copies session metadata', async () => { + const db = await openTestDb(); + try { + await setupSchema(db); + await seedTestData(db, 'source', 3); + + await forkSessionInDb(db, 'source', 'forked', 1); + + const sessions = await all(db, 'SELECT * FROM sessions WHERE id = ?', ['forked']); + assert.strictEqual(sessions.length, 1); + assert.strictEqual(sessions[0].cwd, '/test'); + assert.strictEqual(sessions[0].repository, 'test-repo'); + } finally { + await close(db); + } + }); + + test('copies turns up to fork point', async () => { + const db = await openTestDb(); + try { + await setupSchema(db); + await seedTestData(db, 'source', 3); + + await forkSessionInDb(db, 'source', 'forked', 1); + + const turns = await all(db, 'SELECT * FROM turns WHERE session_id = ? ORDER BY turn_index', ['forked']); + assert.strictEqual(turns.length, 2); // turns 0 and 1 + assert.strictEqual(turns[0].turn_index, 0); + assert.strictEqual(turns[1].turn_index, 1); + } finally { + await close(db); + } + }); + + test('copies session files up to fork point', async () => { + const db = await openTestDb(); + try { + await setupSchema(db); + await seedTestData(db, 'source', 3); + + await forkSessionInDb(db, 'source', 'forked', 1); + + const files = await all(db, 'SELECT * FROM session_files WHERE session_id = ?', ['forked']); + assert.strictEqual(files.length, 2); // files from turns 0 and 1 + } finally { + await close(db); + } + }); + + test('does not affect source session', async () => { + const db = await openTestDb(); + try { + await setupSchema(db); + await seedTestData(db, 'source', 3); + + await forkSessionInDb(db, 'source', 'forked', 1); + + const sourceTurns = await all(db, 'SELECT * FROM turns WHERE session_id = ?', ['source']); + assert.strictEqual(sourceTurns.length, 3); + } finally { + await close(db); + } + }); + }); + + suite('truncateSessionInDb', () => { + + async function openTestDb(): Promise { + const sqlite3 = await import('@vscode/sqlite3'); + return new Promise((resolve, reject) => { + const db = new sqlite3.default.Database(':memory:', (err: Error | null) => { + if (err) { + return reject(err); + } + resolve(db); + }); + }); + } + + function exec(db: import('@vscode/sqlite3').Database, sql: string): Promise { + return new Promise((resolve, reject) => { + db.exec(sql, err => err ? reject(err) : resolve()); + }); + } + + function all(db: import('@vscode/sqlite3').Database, sql: string, params: unknown[] = []): Promise[]> { + return new Promise((resolve, reject) => { + db.all(sql, params, (err: Error | null, rows: Record[]) => { + if (err) { + return reject(err); + } + resolve(rows); + }); + }); + } + + function close(db: import('@vscode/sqlite3').Database): Promise { + return new Promise((resolve, reject) => { + db.close(err => err ? reject(err) : resolve()); + }); + } + + async function setupSchema(db: import('@vscode/sqlite3').Database): Promise { + await exec(db, ` + CREATE TABLE sessions ( + id TEXT PRIMARY KEY, + cwd TEXT, + repository TEXT, + branch TEXT, + summary TEXT, + created_at TEXT, + updated_at TEXT, + host_type TEXT + ); + CREATE TABLE turns ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + turn_index INTEGER NOT NULL, + user_message TEXT, + assistant_response TEXT, + timestamp TEXT, + UNIQUE(session_id, turn_index) + ); + CREATE VIRTUAL TABLE search_index USING fts5( + content, + session_id, + source_type, + source_id + ); + `); + } + + test('removes turns after truncation point', async () => { + const db = await openTestDb(); + try { + await setupSchema(db); + await exec(db, ` + INSERT INTO sessions (id, cwd, summary, created_at, updated_at) + VALUES ('sess', '/test', 'Test', '2026-01-01', '2026-01-01'); + `); + for (let i = 0; i < 5; i++) { + await exec(db, ` + INSERT INTO turns (session_id, turn_index, user_message, timestamp) + VALUES ('sess', ${i}, 'msg ${i}', '2026-01-01'); + `); + } + + await truncateSessionInDb(db, 'sess', 2); + + const turns = await all(db, 'SELECT * FROM turns WHERE session_id = ? ORDER BY turn_index', ['sess']); + assert.strictEqual(turns.length, 3); // turns 0, 1, 2 + assert.strictEqual(turns[0].turn_index, 0); + assert.strictEqual(turns[2].turn_index, 2); + } finally { + await close(db); + } + }); + + test('updates session timestamp', async () => { + const db = await openTestDb(); + try { + await setupSchema(db); + await exec(db, ` + INSERT INTO sessions (id, cwd, summary, created_at, updated_at) + VALUES ('sess', '/test', 'Test', '2026-01-01', '2026-01-01'); + `); + await exec(db, ` + INSERT INTO turns (session_id, turn_index, user_message, timestamp) + VALUES ('sess', 0, 'msg 0', '2026-01-01'); + `); + + await truncateSessionInDb(db, 'sess', 0); + + const sessions = await all(db, 'SELECT updated_at FROM sessions WHERE id = ?', ['sess']); + assert.notStrictEqual(sessions[0].updated_at, '2026-01-01'); + } finally { + await close(db); + } + }); + + test('removes search index entries for truncated turns', async () => { + const db = await openTestDb(); + try { + await setupSchema(db); + await exec(db, ` + INSERT INTO sessions (id, cwd, summary, created_at, updated_at) + VALUES ('sess', '/test', 'Test', '2026-01-01', '2026-01-01'); + `); + for (let i = 0; i < 3; i++) { + await exec(db, ` + INSERT INTO turns (session_id, turn_index, user_message, timestamp) + VALUES ('sess', ${i}, 'msg ${i}', '2026-01-01'); + `); + await exec(db, ` + INSERT INTO search_index (content, session_id, source_type, source_id) + VALUES ('content ${i}', 'sess', 'turn', 'sess:turn:${i}'); + `); + } + + await truncateSessionInDb(db, 'sess', 0); + + const searchEntries = await all(db, 'SELECT * FROM search_index WHERE session_id = ?', ['sess']); + assert.strictEqual(searchEntries.length, 1); + assert.strictEqual(searchEntries[0].source_id, 'sess:turn:0'); + } finally { + await close(db); + } + }); + }); +}); diff --git a/src/vs/sessions/contrib/remoteAgentHost/browser/remoteAgentHost.contribution.ts b/src/vs/sessions/contrib/remoteAgentHost/browser/remoteAgentHost.contribution.ts index c167b8ebadd..7a47d3c412b 100644 --- a/src/vs/sessions/contrib/remoteAgentHost/browser/remoteAgentHost.contribution.ts +++ b/src/vs/sessions/contrib/remoteAgentHost/browser/remoteAgentHost.contribution.ts @@ -346,6 +346,9 @@ export class RemoteAgentHostContribution extends Disposable implements IWorkbenc canDelegate: true, requiresCustomModels: true, supportsDelegation: false, + capabilities: { + supportsCheckpoints: true, + }, })); // Session handler (unified) diff --git a/src/vs/workbench/contrib/chat/browser/actions/chatForkActions.ts b/src/vs/workbench/contrib/chat/browser/actions/chatForkActions.ts index 57b337a8ccc..a10325c2770 100644 --- a/src/vs/workbench/contrib/chat/browser/actions/chatForkActions.ts +++ b/src/vs/workbench/contrib/chat/browser/actions/chatForkActions.ts @@ -12,7 +12,7 @@ import { localize, localize2 } from '../../../../../nls.js'; import { Action2, MenuId, registerAction2 } from '../../../../../platform/actions/common/actions.js'; import { ContextKeyExpr } from '../../../../../platform/contextkey/common/contextkey.js'; import { ServicesAccessor } from '../../../../../platform/instantiation/common/instantiation.js'; -import { ChatContextKeys } from '../../common/actions/chatContextKeys.js'; +import { ChatContextKeyExprs, ChatContextKeys } from '../../common/actions/chatContextKeys.js'; import { IChatService, ResponseModelState } from '../../common/chatService/chatService.js'; import type { ISerializableChatData } from '../../common/model/chatModel.js'; import { isChatTreeItem, isRequestVM, isResponseVM } from '../../common/model/chatViewModel.js'; @@ -40,7 +40,7 @@ export function registerChatForkActions() { ChatContextKeys.isRequest, ChatContextKeys.isFirstRequest.negate(), ContextKeyExpr.or( - ChatContextKeys.lockedToCodingAgent.negate(), + ContextKeyExpr.or(ChatContextKeys.lockedToCodingAgent.negate(), ChatContextKeyExprs.isAgentHostSession), ChatContextKeys.chatSessionSupportsFork ) ) diff --git a/src/vs/workbench/contrib/chat/browser/actions/chatNewActions.ts b/src/vs/workbench/contrib/chat/browser/actions/chatNewActions.ts index cdc8b4e2858..a76f84dd159 100644 --- a/src/vs/workbench/contrib/chat/browser/actions/chatNewActions.ts +++ b/src/vs/workbench/contrib/chat/browser/actions/chatNewActions.ts @@ -14,7 +14,7 @@ import { ContextKeyExpr } from '../../../../../platform/contextkey/common/contex import { IDialogService } from '../../../../../platform/dialogs/common/dialogs.js'; import { KeybindingWeight } from '../../../../../platform/keybinding/common/keybindingsRegistry.js'; import { IViewsService } from '../../../../services/views/common/viewsService.js'; -import { ChatContextKeys } from '../../common/actions/chatContextKeys.js'; +import { ChatContextKeyExprs, ChatContextKeys } from '../../common/actions/chatContextKeys.js'; import { IChatEditingSession } from '../../common/editing/chatEditingService.js'; import { IChatService } from '../../common/chatService/chatService.js'; import { ChatAgentLocation, ChatConfiguration, ChatModeKind } from '../../common/constants.js'; @@ -278,7 +278,7 @@ export function registerNewChatActions() { f1: true, menu: [{ id: MenuId.ChatMessageRestoreCheckpoint, - when: ChatContextKeys.lockedToCodingAgent.negate(), + when: ContextKeyExpr.or(ChatContextKeys.lockedToCodingAgent.negate(), ChatContextKeyExprs.isAgentHostSession), group: 'navigation', order: -1 }] diff --git a/src/vs/workbench/contrib/chat/browser/agentSessions/agentHost/agentHostChatContribution.ts b/src/vs/workbench/contrib/chat/browser/agentSessions/agentHost/agentHostChatContribution.ts index eb91bcc9797..4e1b931dbd7 100644 --- a/src/vs/workbench/contrib/chat/browser/agentSessions/agentHost/agentHostChatContribution.ts +++ b/src/vs/workbench/contrib/chat/browser/agentSessions/agentHost/agentHostChatContribution.ts @@ -147,6 +147,9 @@ export class AgentHostContribution extends Disposable implements IWorkbenchContr description: agent.description, canDelegate: true, requiresCustomModels: true, + capabilities: { + supportsCheckpoints: true, + }, })); // Session list controller diff --git a/src/vs/workbench/contrib/chat/browser/agentSessions/agentHost/agentHostEditingSession.ts b/src/vs/workbench/contrib/chat/browser/agentSessions/agentHost/agentHostEditingSession.ts index 019e7459e33..d604f351efb 100644 --- a/src/vs/workbench/contrib/chat/browser/agentSessions/agentHost/agentHostEditingSession.ts +++ b/src/vs/workbench/contrib/chat/browser/agentSessions/agentHost/agentHostEditingSession.ts @@ -46,7 +46,8 @@ interface IAgentHostFileEdit { interface IAgentHostCheckpoint { readonly requestId: string; - readonly undoStopId: string; + /** Tool-call ID, or `undefined` for the sentinel checkpoint at request start. */ + readonly undoStopId: string | undefined; readonly edits: IAgentHostFileEdit[]; } @@ -120,7 +121,26 @@ export class AgentHostEditingSession extends Disposable implements IChatEditingS private readonly _entriesObs = observableValue(this, []); readonly entries: IObservable = this._entriesObs; - readonly requestDisablement: IObservable = constObservable([]); + readonly requestDisablement: IObservable = derivedOpts( + { equalsFn: (a, b) => a.length === b.length && a.every((v, i) => v.requestId === b[i].requestId && v.afterUndoStop === b[i].afterUndoStop) }, + reader => { + const currentIdx = this._currentCheckpointIndex.read(reader); + if (currentIdx >= this._checkpoints.length - 1) { + return []; + } + // Collect unique request IDs from checkpoints after the current + // index. Keep the first entry per request — if that's the sentinel + // (undoStopId === undefined) the entire request is disabled. + const disabled = new Map(); + for (let i = currentIdx + 1; i < this._checkpoints.length; i++) { + const cp = this._checkpoints[i]; + if (!disabled.has(cp.requestId)) { + disabled.set(cp.requestId, cp.undoStopId); + } + } + return [...disabled].map(([requestId, afterUndoStop]): IChatRequestDisablement => ({ requestId, afterUndoStop })); + }, + ); private readonly _onDidDispose = this._register(new Emitter()); readonly onDidDispose: Event = this._onDidDispose.event; @@ -154,6 +174,27 @@ export class AgentHostEditingSession extends Disposable implements IChatEditingS // ---- Hydration from protocol state -------------------------------------- + /** + * Ensures a sentinel checkpoint exists for the given request. Called at the + * start of every turn so that `requestDisablement` and `restoreSnapshot` + * can reference requests that may not produce any file edits. + * + * Also splices away stale checkpoints after the current index (undo branch + * semantics) when a new request arrives after a checkpoint restore. + */ + ensureRequestCheckpoint(requestId: string): void { + // Splice stale checkpoints if the user restored a checkpoint + const currentIdx = this._currentCheckpointIndex.get(); + if (currentIdx < this._checkpoints.length - 1) { + this._checkpoints.splice(currentIdx + 1); + } + + // Insert sentinel for this request if it doesn't exist yet + if (!this._checkpoints.some(cp => cp.requestId === requestId)) { + this._checkpoints.push({ requestId, undoStopId: undefined, edits: [] }); + } + } + addToolCallEdits(requestId: string, tc: IToolCallState): IChatProgress[] { if (tc.status !== ToolCallStatus.Completed) { return []; @@ -164,6 +205,9 @@ export class AgentHostEditingSession extends Disposable implements IChatEditingS return []; } + // Ensure the sentinel and undo-branch splice are handled + this.ensureRequestCheckpoint(requestId); + const fileEdits = fileEditsToExternalEdits(tc); if (fileEdits.length === 0) { return []; @@ -250,43 +294,54 @@ export class AgentHostEditingSession extends Disposable implements IChatEditingS if (stopId !== undefined) { return this._checkpoints.findIndex(cp => cp.requestId === requestId && cp.undoStopId === stopId); } - // No specific stop: use the last checkpoint for this request - for (let i = this._checkpoints.length - 1; i >= 0; i--) { - if (this._checkpoints[i].requestId === requestId) { - return i; - } - } - return -1; + // No specific stop: find the sentinel checkpoint (undoStopId === undefined) + // for this request, which marks the request boundary. + return this._checkpoints.findIndex(cp => cp.requestId === requestId && cp.undoStopId === undefined); } private _findCheckpoint(requestId: string, stopId: string | undefined): IAgentHostCheckpoint | undefined { - const idx = this._findCheckpointIndex(requestId, stopId); - return idx >= 0 ? this._checkpoints[idx] : undefined; + if (stopId !== undefined) { + const idx = this._findCheckpointIndex(requestId, stopId); + return idx >= 0 ? this._checkpoints[idx] : undefined; + } + // No specific stop: find the last non-sentinel checkpoint for this + // request (the one with actual edits). + for (let i = this._checkpoints.length - 1; i >= 0; i--) { + const cp = this._checkpoints[i]; + if (cp.requestId === requestId && cp.undoStopId !== undefined) { + return cp; + } + } + return undefined; } async restoreSnapshot(requestId: string, stopId: string | undefined): Promise { - const idx = this._findCheckpointIndex(requestId, stopId); - if (idx < 0) { + const cpIdx = this._findCheckpointIndex(requestId, stopId); + if (cpIdx < 0) { this._logService.warn(`[AgentHostEditingSession] No checkpoint found for requestId=${requestId}${stopId ? `, stopId=${stopId}` : ''}`); return; } + // When stopId is undefined we found the sentinel (request boundary). + // Navigate to one before it so the request's edits are fully undone. + const targetIdx = stopId === undefined ? cpIdx - 1 : cpIdx; + // Navigate to the target checkpoint const currentIdx = this._currentCheckpointIndex.get(); - if (idx < currentIdx) { + if (targetIdx < currentIdx) { // Undo forward checkpoints - for (let i = currentIdx; i > idx; i--) { + for (let i = currentIdx; i > targetIdx; i--) { await this._writeCheckpointContent(this._checkpoints[i], 'before'); } - } else if (idx > currentIdx) { + } else if (targetIdx > currentIdx) { // Redo to reach the target - for (let i = currentIdx + 1; i <= idx; i++) { + for (let i = currentIdx + 1; i <= targetIdx; i++) { await this._writeCheckpointContent(this._checkpoints[i], 'after'); } } transaction(tx => { - this._currentCheckpointIndex.set(idx, tx); + this._currentCheckpointIndex.set(targetIdx, tx); }); this._rebuildEntries(); } @@ -558,8 +613,14 @@ export class AgentHostEditingSession extends Disposable implements IChatEditingS await this._writeCheckpointContent(this._checkpoints[idx], 'before'); + // Skip past any sentinel checkpoints (they have no edits) + let newIdx = idx - 1; + while (newIdx >= 0 && this._checkpoints[newIdx].undoStopId === undefined) { + newIdx--; + } + transaction(tx => { - this._currentCheckpointIndex.set(idx - 1, tx); + this._currentCheckpointIndex.set(newIdx, tx); }); this._rebuildEntries(); } @@ -570,7 +631,15 @@ export class AgentHostEditingSession extends Disposable implements IChatEditingS return; } - const nextIdx = idx + 1; + // Skip past sentinel checkpoints to the next tool-call checkpoint + let nextIdx = idx + 1; + while (nextIdx < this._checkpoints.length && this._checkpoints[nextIdx].undoStopId === undefined) { + nextIdx++; + } + if (nextIdx >= this._checkpoints.length) { + return; + } + await this._writeCheckpointContent(this._checkpoints[nextIdx], 'after'); transaction(tx => { diff --git a/src/vs/workbench/contrib/chat/browser/agentSessions/agentHost/agentHostSessionHandler.ts b/src/vs/workbench/contrib/chat/browser/agentSessions/agentHost/agentHostSessionHandler.ts index 12758f614c7..9106ded0cf8 100644 --- a/src/vs/workbench/contrib/chat/browser/agentSessions/agentHost/agentHostSessionHandler.ts +++ b/src/vs/workbench/contrib/chat/browser/agentSessions/agentHost/agentHostSessionHandler.ts @@ -12,9 +12,9 @@ import { ResourceMap } from '../../../../../../base/common/map.js'; import { observableValue } from '../../../../../../base/common/observable.js'; import { isEqual } from '../../../../../../base/common/resources.js'; import { URI } from '../../../../../../base/common/uri.js'; -import { generateUuid } from '../../../../../../base/common/uuid.js'; import { localize } from '../../../../../../nls.js'; import { AgentProvider, AgentSession, IAgentAttachment, type IAgentConnection } from '../../../../../../platform/agentHost/common/agentService.js'; +import { ISessionTruncatedAction } from '../../../../../../platform/agentHost/common/state/protocol/actions.js'; import { ActionType, isSessionAction, type ISessionAction } from '../../../../../../platform/agentHost/common/state/sessionActions.js'; import { SessionClientState } from '../../../../../../platform/agentHost/common/state/sessionClientState.js'; import { AHP_AUTH_REQUIRED, ProtocolError } from '../../../../../../platform/agentHost/common/state/sessionProtocol.js'; @@ -26,7 +26,7 @@ import { ILogService } from '../../../../../../platform/log/common/log.js'; import { IProductService } from '../../../../../../platform/product/common/productService.js'; import { IWorkspaceContextService } from '../../../../../../platform/workspace/common/workspace.js'; import { ChatRequestQueueKind, IChatProgress, IChatService, IChatToolInvocation, ToolConfirmKind } from '../../../common/chatService/chatService.js'; -import { IChatSession, IChatSessionContentProvider, IChatSessionHistoryItem } from '../../../common/chatSessionsService.js'; +import { IChatSession, IChatSessionContentProvider, IChatSessionHistoryItem, IChatSessionItem, IChatSessionRequestHistoryItem } from '../../../common/chatSessionsService.js'; import { ChatAgentLocation, ChatModeKind } from '../../../common/constants.js'; import { IChatEditingService } from '../../../common/editing/chatEditingService.js'; import { ChatToolInvocation } from '../../../common/model/chatProgressTypes/chatToolInvocation.js'; @@ -59,11 +59,13 @@ class AgentHostChatSession extends Disposable implements IChatSession { readonly requestHandler: IChatSession['requestHandler']; interruptActiveResponseCallback: IChatSession['interruptActiveResponseCallback']; + readonly forkSession: IChatSession['forkSession']; constructor( readonly sessionResource: URI, readonly history: readonly IChatSessionHistoryItem[], private readonly _sendRequest: (request: IChatAgentRequest, progress: (parts: IChatProgress[]) => void, token: CancellationToken) => Promise, + private readonly _forkSession: ((request: IChatSessionRequestHistoryItem | undefined, token: CancellationToken) => Promise) | undefined, initialProgress: IChatProgress[] | undefined, onDispose: () => void, @ILogService private readonly _logService: ILogService, @@ -91,6 +93,8 @@ class AgentHostChatSession extends Disposable implements IChatSession { this.interruptActiveResponseCallback = (hasActiveTurn || history.length === 0) ? async () => { return true; } : undefined; + + this.forkSession = this._forkSession; } /** @@ -291,6 +295,9 @@ export class AgentHostSessionHandler extends Disposable implements IChatSessionC this._ensurePendingMessageSubscription(sessionResource, backendSession); return this._handleTurn(backendSession, request, progress, token); }, + resolvedSession + ? (request, token) => this._forkSession(sessionResource, resolvedSession!, request, token) + : undefined, initialProgress, () => { this._activeSessions.delete(sessionResource); @@ -724,7 +731,7 @@ export class AgentHostSessionHandler extends Disposable implements IChatSessionC return; } - const turnId = generateUuid(); + const turnId = request.requestId; this._clientDispatchedTurnIds.add(turnId); const cleanUpTurnId = () => this._clientDispatchedTurnIds.delete(turnId); const attachments = this._convertVariablesToAttachments(request); @@ -751,6 +758,36 @@ export class AgentHostSessionHandler extends Disposable implements IChatSessionC } } + // If the chat model has fewer previous requests than the protocol has + // turns, a checkpoint was restored or a message was edited. Dispatch + // session/truncated so the server drops the stale tail. + const chatModel = this._chatService.getSession(request.sessionResource); + const protocolState = this._clientState.getSessionState(session.toString()); + if (chatModel && protocolState && protocolState.turns.length > 0) { + // -2 since -1 will already be the current request + const previousRequestIndex = chatModel.getRequests().findIndex(i => i.id === request.requestId) - 1; + const previousRequest = previousRequestIndex >= 0 ? chatModel.getRequests()[previousRequestIndex] : undefined; + if (!previousRequest && protocolState.turns.length > 0) { + const truncateAction: ISessionTruncatedAction = { + type: ActionType.SessionTruncated, + session: session.toString(), + }; + const truncateSeq = this._clientState.applyOptimistic(truncateAction); + this._config.connection.dispatchAction(truncateAction, this._clientState.clientId, truncateSeq); + } else { + const seenAtIndex = protocolState.turns.findIndex(t => t.id === previousRequest!.id); + if (seenAtIndex !== -1 && seenAtIndex < protocolState.turns.length - 1) { + const truncateAction: ISessionTruncatedAction = { + type: ActionType.SessionTruncated, + session: session.toString(), + turnId: previousRequest!.id, + }; + const truncateSeq = this._clientState.applyOptimistic(truncateAction); + this._config.connection.dispatchAction(truncateAction, this._clientState.clientId, truncateSeq); + } + } + } + // Dispatch session/turnStarted — the server will call sendMessage on // the provider as a side effect. const turnAction = { @@ -765,6 +802,12 @@ export class AgentHostSessionHandler extends Disposable implements IChatSessionC const clientSeq = this._clientState.applyOptimistic(turnAction); this._config.connection.dispatchAction(turnAction, this._clientState.clientId, clientSeq); + // Ensure the editing session records a sentinel checkpoint for this + // request so it appears in requestDisablement even if the turn + // produces no file edits. + this._ensureEditingSession(request.sessionResource) + ?.ensureRequestCheckpoint(request.requestId); + // Track live ChatToolInvocation objects for this turn const activeToolInvocations = new Map(); @@ -1211,14 +1254,68 @@ export class AgentHostSessionHandler extends Disposable implements IChatSessionC return AgentSession.uri(this._config.provider, rawId); } + /** + * Forks a session at the given request point by creating a new backend + * session with the `fork` parameter. Returns an {@link IChatSessionItem} + * pointing to the newly created session. + */ + private async _forkSession( + sessionResource: URI, + backendSession: URI, + request: IChatSessionRequestHistoryItem | undefined, + token: CancellationToken, + ): Promise { + if (token.isCancellationRequested) { + throw new Error('Cancelled'); + } + + // Determine the turn index to fork at. If a specific request is + // provided, find its position in the protocol state's turn list. + // Otherwise fork the entire session. + const protocolState = this._clientState.getSessionState(backendSession.toString()); + let turnIndex: number | undefined; + if (request) { + turnIndex = protocolState?.turns.findIndex(t => t.id === request.id); + if (turnIndex === undefined || turnIndex < 0) { + throw new Error(`Cannot fork: turn for request ${request.id} not found in protocol state`); + } + } else if (protocolState && protocolState.turns.length > 0) { + turnIndex = protocolState.turns.length - 1; + } + + if (turnIndex === undefined) { + throw new Error('Cannot fork: no turns to fork from'); + } + + const chatModel = this._chatService.getSession(sessionResource); + + const forkedSession = await this._createAndSubscribe(sessionResource, undefined, { + session: backendSession, + turnIndex, + }); + + const forkedRawId = AgentSession.id(forkedSession); + const forkedResource = URI.from({ scheme: this._config.sessionType, path: `/${forkedRawId}` }); + const now = Date.now(); + + return { + resource: forkedResource, + label: chatModel?.title + ? localize('chat.forked.title', "Forked: {0}", chatModel.title) + : `Forked session`, + iconPath: getAgentHostIcon(this._productService), + timing: { created: now, lastRequestStarted: now, lastRequestEnded: now }, + }; + } + /** Creates a new backend session and subscribes to its state. */ - private async _createAndSubscribe(sessionResource: URI, modelId?: string): Promise { + private async _createAndSubscribe(sessionResource: URI, modelId?: string, fork?: { session: URI; turnIndex: number }): Promise { const rawModelId = this._extractRawModelId(modelId); const resourceKey = sessionResource.path.substring(1); const workingDirectory = this._config.resolveWorkingDirectory?.(resourceKey) ?? this._workspaceContextService.getWorkspace().folders[0]?.uri; - this._logService.trace(`[AgentHost] Creating new session, model=${rawModelId ?? '(default)'}, provider=${this._config.provider}`); + this._logService.trace(`[AgentHost] Creating new session, model=${rawModelId ?? '(default)'}, provider=${this._config.provider}${fork ? `, fork from ${fork.session.toString()} at index ${fork.turnIndex}` : ''}`); let session: URI; try { @@ -1226,6 +1323,7 @@ export class AgentHostSessionHandler extends Disposable implements IChatSessionC model: rawModelId, provider: this._config.provider, workingDirectory, + fork, }); } catch (err) { // If authentication is required, try to resolve it and retry once @@ -1237,6 +1335,7 @@ export class AgentHostSessionHandler extends Disposable implements IChatSessionC model: rawModelId, provider: this._config.provider, workingDirectory, + fork, }); } else { throw new Error(localize('agentHost.authRequired', "Authentication is required to start a session. Please sign in and try again.")); diff --git a/src/vs/workbench/contrib/chat/browser/chatEditing/chatEditingActions.ts b/src/vs/workbench/contrib/chat/browser/chatEditing/chatEditingActions.ts index 709c7664063..b4599937751 100644 --- a/src/vs/workbench/contrib/chat/browser/chatEditing/chatEditingActions.ts +++ b/src/vs/workbench/contrib/chat/browser/chatEditing/chatEditingActions.ts @@ -30,7 +30,7 @@ import { IEditorService } from '../../../../services/editor/common/editorService import { IAgentSessionsService } from '../agentSessions/agentSessionsService.js'; import { IChatRequestVariableEntry, isImplicitVariableEntry, isPromptFileVariableEntry, isPromptTextVariableEntry, isStringVariableEntry, isWorkspaceVariableEntry } from '../../common/attachments/chatVariableEntries.js'; import { isChatViewTitleActionContext } from '../../common/actions/chatActions.js'; -import { ChatContextKeys } from '../../common/actions/chatContextKeys.js'; +import { ChatContextKeyExprs, ChatContextKeys } from '../../common/actions/chatContextKeys.js'; import { applyingChatEditsFailedContextKey, CHAT_EDITING_MULTI_DIFF_SOURCE_RESOLVER_SCHEME, chatEditingResourceContextKey, chatEditingWidgetFileStateContextKey, decidedChatEditingResourceContextKey, hasAppliedChatEditsContextKey, hasUndecidedChatEditingResourceContextKey, IChatEditingService, IChatEditingSession, ModifiedFileEntryState } from '../../common/editing/chatEditingService.js'; import { IChatService } from '../../common/chatService/chatService.js'; import { isChatTreeItem, isRequestVM, isResponseVM } from '../../common/model/chatViewModel.js'; @@ -533,7 +533,7 @@ registerAction2(class RemoveAction extends Action2 { id: MenuId.ChatMessageTitle, group: 'navigation', order: 2, - when: ContextKeyExpr.and(ContextKeyExpr.equals(`config.${ChatConfiguration.EditRequests}`, 'input').negate(), ContextKeyExpr.equals(`config.${ChatConfiguration.CheckpointsEnabled}`, false), ChatContextKeys.lockedToCodingAgent.negate()), + when: ContextKeyExpr.and(ContextKeyExpr.equals(`config.${ChatConfiguration.EditRequests}`, 'input').negate(), ContextKeyExpr.equals(`config.${ChatConfiguration.CheckpointsEnabled}`, false), ContextKeyExpr.or(ChatContextKeys.lockedToCodingAgent.negate(), ChatContextKeyExprs.isAgentHostSession)), } ] }); @@ -586,7 +586,7 @@ registerAction2(class RestoreCheckpointAction extends Action2 { id: MenuId.ChatMessageCheckpoint, group: 'navigation', order: 2, - when: ContextKeyExpr.and(ChatContextKeys.isRequest, ChatContextKeys.lockedToCodingAgent.negate(), ChatContextKeys.isFirstRequest.negate()) + when: ContextKeyExpr.and(ChatContextKeys.isRequest, ContextKeyExpr.or(ChatContextKeys.lockedToCodingAgent.negate(), ChatContextKeyExprs.isAgentHostSession), ChatContextKeys.isFirstRequest.negate()) } ] }); @@ -633,7 +633,7 @@ registerAction2(class StartOverAction extends Action2 { id: MenuId.ChatMessageCheckpoint, group: 'navigation', order: 2, - when: ContextKeyExpr.and(ChatContextKeys.isRequest, ChatContextKeys.lockedToCodingAgent.negate(), ChatContextKeys.isFirstRequest) + when: ContextKeyExpr.and(ChatContextKeys.isRequest, ContextKeyExpr.or(ChatContextKeys.lockedToCodingAgent.negate(), ChatContextKeyExprs.isAgentHostSession), ChatContextKeys.isFirstRequest) } ] }); @@ -667,7 +667,7 @@ registerAction2(class RestoreLastCheckpoint extends Action2 { precondition: ContextKeyExpr.and( ChatContextKeys.inChatSession, ContextKeyExpr.equals(`config.${ChatConfiguration.CheckpointsEnabled}`, true), - ChatContextKeys.lockedToCodingAgent.negate() + ContextKeyExpr.or(ChatContextKeys.lockedToCodingAgent.negate(), ChatContextKeyExprs.isAgentHostSession) ) }); } diff --git a/src/vs/workbench/contrib/chat/browser/widget/chatWidget.ts b/src/vs/workbench/contrib/chat/browser/widget/chatWidget.ts index 7689707940f..1495a7924a6 100644 --- a/src/vs/workbench/contrib/chat/browser/widget/chatWidget.ts +++ b/src/vs/workbench/contrib/chat/browser/widget/chatWidget.ts @@ -199,6 +199,7 @@ const supportsAllAttachments: Required = { supportsTerminalAttachments: true, supportsPromptAttachments: true, supportsHandOffs: true, + supportsCheckpoints: true, }; const DISCLAIMER = localize('chatDisclaimer', "AI responses may be inaccurate"); @@ -2159,7 +2160,8 @@ export class ChatWidget extends Disposable implements IChatWidget { // Update capabilities for the locked agent const agent = this.chatAgentService.getAgent(agentId); this._updateAgentCapabilitiesContextKeys(agent); - this.listWidget?.updateRendererOptions({ restorable: false, editable: false, noFooter: true, progressMessageAtBottomOfResponse: true }); + const supportsCheckpoints = this._attachmentCapabilities.supportsCheckpoints ?? false; + this.listWidget?.updateRendererOptions({ restorable: supportsCheckpoints, editable: supportsCheckpoints, noFooter: true, progressMessageAtBottomOfResponse: true }); if (this.visible) { this.listWidget?.rerender(); } @@ -2337,8 +2339,17 @@ export class ChatWidget extends Disposable implements IChatWidget { options.queue = undefined; } + // For agents that support checkpoints, preserve the checkpoint + // through finishedEditing so blocked requests are removed below + // and the agent host can dispatch a protocol truncation action. + const preserveCheckpoint = this._lockedAgent && !!this._attachmentCapabilities.supportsCheckpoints; + if (preserveCheckpoint) { + this.recentlyRestoredCheckpoint = true; + } this.finishedEditing(true); - this.viewModel.model?.setCheckpoint(undefined); + if (!preserveCheckpoint) { + this.viewModel.model?.setCheckpoint(undefined); + } } const model = this.viewModel.model; @@ -2408,6 +2419,7 @@ export class ChatWidget extends Disposable implements IChatWidget { this.chatService.removeRequest(this.viewModel.sessionResource, request.id); } } + this.viewModel.model.setCheckpoint(undefined); } // Expand directory attachments: extract images as binary entries const resolvedImageVariables = await this._resolveDirectoryImageAttachments(requestInputs.attachedContext.asArray()); diff --git a/src/vs/workbench/contrib/chat/common/actions/chatContextKeys.ts b/src/vs/workbench/contrib/chat/common/actions/chatContextKeys.ts index 146941eba61..e55586dde7f 100644 --- a/src/vs/workbench/contrib/chat/common/actions/chatContextKeys.ts +++ b/src/vs/workbench/contrib/chat/common/actions/chatContextKeys.ts @@ -161,6 +161,15 @@ export namespace ChatContextKeyExprs { ChatContextKeys.chatModeKind.isEqualTo(ChatModeKind.Agent), ); + /** + * True when the locked coding agent is an agent host session (agent-host-* or remote-*). + * These sessions use AgentHostEditingSession which supports checkpoint-based undo/redo. + */ + export const isAgentHostSession = ContextKeyExpr.or( + ContextKeyExpr.regex(ChatContextKeys.lockedCodingAgentId.key, /^agent-host-/), + ContextKeyExpr.regex(ChatContextKeys.lockedCodingAgentId.key, /^remote-/), + ); + /** * Context expression that indicates when the welcome/setup view should be shown */ diff --git a/src/vs/workbench/contrib/chat/common/participants/chatAgents.ts b/src/vs/workbench/contrib/chat/common/participants/chatAgents.ts index 80728aae75a..f65d4011116 100644 --- a/src/vs/workbench/contrib/chat/common/participants/chatAgents.ts +++ b/src/vs/workbench/contrib/chat/common/participants/chatAgents.ts @@ -50,6 +50,7 @@ export interface IChatAgentAttachmentCapabilities { supportsTerminalAttachments?: boolean; supportsPromptAttachments?: boolean; supportsHandOffs?: boolean; + supportsCheckpoints?: boolean; } export interface IChatAgentData { diff --git a/src/vs/workbench/contrib/chat/test/browser/agentHost/agentHostEditingSession.test.ts b/src/vs/workbench/contrib/chat/test/browser/agentHost/agentHostEditingSession.test.ts index 956bbb28710..b4a0ffc155c 100644 --- a/src/vs/workbench/contrib/chat/test/browser/agentHost/agentHostEditingSession.test.ts +++ b/src/vs/workbench/contrib/chat/test/browser/agentHost/agentHostEditingSession.test.ts @@ -675,4 +675,326 @@ suite('AgentHostEditingSession', () => { assert.strictEqual(diff!.isFinal, true); }); }); + + suite('requestDisablement', () => { + test('returns empty when at the latest checkpoint', () => { + const session = createSession(store, new Map()); + + session.addToolCallEdits('req-1', makeToolCall({ + toolCallId: 'tc-1', + filePath: '/workspace/file.ts', + beforeURI: 'b', + afterURI: 'a', + })); + + assert.deepStrictEqual(session.requestDisablement.get(), []); + }); + + test('disables requests after undo', async () => { + const contentMap = new Map(); + contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before'); + const session = createSession(store, contentMap); + + session.addToolCallEdits('req-1', makeToolCall({ + toolCallId: 'tc-1', + filePath: '/workspace/file.ts', + beforeURI: URI.file('/workspace/file.ts').toString(), + afterURI: 'content://after-1', + })); + + session.addToolCallEdits('req-2', makeToolCall({ + toolCallId: 'tc-2', + filePath: '/workspace/file.ts', + beforeURI: 'content://after-1', + afterURI: 'content://after-2', + })); + + // Undo the last tool call — req-2's sentinel + tool call are now after the cursor + await session.undoInteraction(); + + const disabled = session.requestDisablement.get(); + assert.strictEqual(disabled.length, 1); + assert.strictEqual(disabled[0].requestId, 'req-2'); + // The first checkpoint after current is the sentinel (undoStopId === undefined) + assert.strictEqual(disabled[0].afterUndoStop, undefined); + }); + + test('clears disablement after redo', async () => { + const contentMap = new Map(); + contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before'); + contentMap.set(toAgentHostUri(URI.parse('content://after-2'), 'local').toString(), 'after-2'); + const session = createSession(store, contentMap); + + session.addToolCallEdits('req-1', makeToolCall({ + toolCallId: 'tc-1', + filePath: '/workspace/file.ts', + beforeURI: URI.file('/workspace/file.ts').toString(), + afterURI: 'content://after-1', + })); + + session.addToolCallEdits('req-2', makeToolCall({ + toolCallId: 'tc-2', + filePath: '/workspace/file.ts', + beforeURI: 'content://after-1', + afterURI: 'content://after-2', + })); + + await session.undoInteraction(); + assert.strictEqual(session.requestDisablement.get().length, 1); + + await session.redoInteraction(); + assert.deepStrictEqual(session.requestDisablement.get(), []); + }); + }); + + suite('restoreSnapshot', () => { + test('restoreSnapshot with undefined stopId navigates before the request', async () => { + const contentMap = new Map(); + contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before'); + contentMap.set(toAgentHostUri(URI.parse('content://after-1'), 'local').toString(), 'after-1'); + const session = createSession(store, contentMap); + + session.addToolCallEdits('req-1', makeToolCall({ + toolCallId: 'tc-1', + filePath: '/workspace/file.ts', + beforeURI: URI.file('/workspace/file.ts').toString(), + afterURI: 'content://after-1', + })); + + session.addToolCallEdits('req-2', makeToolCall({ + toolCallId: 'tc-2', + filePath: '/workspace/file.ts', + beforeURI: 'content://after-1', + afterURI: 'content://after-2', + })); + + // Restore to before req-2 — should undo req-2's edits + const writes: IWriteFileParams[] = []; + store.add(session.onDidRequestFileWrite(p => writes.push(p))); + + await session.restoreSnapshot('req-2', undefined); + + // req-2 has a tool checkpoint whose before-content should be written + assert.ok(writes.length > 0); + // Entries should only show req-1's edits + assert.strictEqual(session.entries.get().length, 1); + assert.strictEqual(session.entries.get()[0].lastModifyingRequestId, 'req-1'); + // req-2 should be disabled + assert.strictEqual(session.requestDisablement.get().length, 1); + assert.strictEqual(session.requestDisablement.get()[0].requestId, 'req-2'); + }); + + test('restoreSnapshot with stopId navigates to that checkpoint', async () => { + const contentMap = new Map(); + contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before'); + contentMap.set(toAgentHostUri(URI.parse('content://after-1'), 'local').toString(), 'after-1'); + const session = createSession(store, contentMap); + + session.addToolCallEdits('req-1', makeToolCall({ + toolCallId: 'tc-1', + filePath: '/workspace/file.ts', + beforeURI: URI.file('/workspace/file.ts').toString(), + afterURI: 'content://after-1', + })); + + session.addToolCallEdits('req-1', makeToolCall({ + toolCallId: 'tc-2', + filePath: '/workspace/file.ts', + beforeURI: 'content://after-1', + afterURI: 'content://after-2', + })); + + // Restore to specific tool call tc-1 within req-1 + await session.restoreSnapshot('req-1', 'tc-1'); + + // Should keep tc-1 but not tc-2 + assert.strictEqual(session.canUndo.get(), true); + assert.strictEqual(session.canRedo.get(), true); + }); + + test('restoreSnapshot for first request navigates to before all edits', async () => { + const contentMap = new Map(); + contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before'); + const session = createSession(store, contentMap); + + session.addToolCallEdits('req-1', makeToolCall({ + toolCallId: 'tc-1', + filePath: '/workspace/file.ts', + beforeURI: URI.file('/workspace/file.ts').toString(), + afterURI: 'content://after-1', + })); + + await session.restoreSnapshot('req-1', undefined); + + // No entries visible, all disabled + assert.deepStrictEqual(session.entries.get(), []); + assert.strictEqual(session.canUndo.get(), false); + assert.strictEqual(session.requestDisablement.get().length, 1); + assert.strictEqual(session.requestDisablement.get()[0].requestId, 'req-1'); + }); + }); + + suite('undo branch (splice stale checkpoints)', () => { + test('new edits after undo remove stale checkpoints', async () => { + const contentMap = new Map(); + contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before'); + const session = createSession(store, contentMap); + + session.addToolCallEdits('req-1', makeToolCall({ + toolCallId: 'tc-1', + filePath: '/workspace/file.ts', + beforeURI: URI.file('/workspace/file.ts').toString(), + afterURI: 'content://after-1', + })); + + session.addToolCallEdits('req-2', makeToolCall({ + toolCallId: 'tc-2', + filePath: '/workspace/file.ts', + beforeURI: 'content://after-1', + afterURI: 'content://after-2', + })); + + // Undo last checkpoint + await session.undoInteraction(); + + // Now add a new edit — should splice away req-2's sentinel + checkpoint + session.addToolCallEdits('req-3', makeToolCall({ + toolCallId: 'tc-3', + filePath: '/workspace/file.ts', + beforeURI: 'content://after-1', + afterURI: 'content://after-3', + })); + + // req-2 should be gone, req-3 should be present + assert.strictEqual(session.canRedo.get(), false); + assert.strictEqual(session.requestDisablement.get().length, 0); + assert.strictEqual(session.hasEditsInRequest('req-2'), false); + assert.strictEqual(session.hasEditsInRequest('req-3'), true); + }); + + test('sentinel checkpoint is preserved after splice for new request', async () => { + const contentMap = new Map(); + contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before'); + const session = createSession(store, contentMap); + + session.addToolCallEdits('req-1', makeToolCall({ + toolCallId: 'tc-1', + filePath: '/workspace/file.ts', + beforeURI: URI.file('/workspace/file.ts').toString(), + afterURI: 'content://after-1', + })); + + // Undo + await session.undoInteraction(); + + // Add new request — sentinel should survive the splice + session.addToolCallEdits('req-2', makeToolCall({ + toolCallId: 'tc-2', + filePath: '/workspace/file.ts', + beforeURI: URI.file('/workspace/file.ts').toString(), + afterURI: 'content://after-2', + })); + + // Undo once (tc-2), then check that req-2 is disabled via the sentinel + await session.undoInteraction(); + const disabled = session.requestDisablement.get(); + assert.strictEqual(disabled.length, 1); + assert.strictEqual(disabled[0].requestId, 'req-2'); + assert.strictEqual(disabled[0].afterUndoStop, undefined); + }); + }); + + suite('ensureRequestCheckpoint', () => { + test('creates sentinel for request without tool calls', () => { + const session = createSession(store, new Map()); + + session.ensureRequestCheckpoint('req-1'); + + // No entries visible (sentinel has no edits) + assert.deepStrictEqual(session.entries.get(), []); + // hasEditsInRequest returns true because the sentinel exists + assert.strictEqual(session.hasEditsInRequest('req-1'), true); + }); + + test('request without edits appears in requestDisablement after undo', async () => { + const contentMap = new Map(); + contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before'); + const session = createSession(store, contentMap); + + // req-1 has file edits + session.addToolCallEdits('req-1', makeToolCall({ + toolCallId: 'tc-1', + filePath: '/workspace/file.ts', + beforeURI: URI.file('/workspace/file.ts').toString(), + afterURI: 'content://after-1', + })); + + // req-2 has no file edits, only a sentinel + session.ensureRequestCheckpoint('req-2'); + + // Undo tc-1 — both req-1 (partially) and req-2 should be disabled + await session.undoInteraction(); + + const disabled = session.requestDisablement.get(); + const disabledIds = disabled.map(d => d.requestId); + assert.ok(disabledIds.includes('req-2'), 'req-2 should be disabled'); + }); + + test('restoreSnapshot works for request with only sentinel', async () => { + const contentMap = new Map(); + contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before'); + const session = createSession(store, contentMap); + + session.addToolCallEdits('req-1', makeToolCall({ + toolCallId: 'tc-1', + filePath: '/workspace/file.ts', + beforeURI: URI.file('/workspace/file.ts').toString(), + afterURI: 'content://after-1', + })); + + session.ensureRequestCheckpoint('req-2'); + + // Restore to before req-2 — should keep req-1's edits but disable req-2 + await session.restoreSnapshot('req-2', undefined); + + assert.strictEqual(session.entries.get().length, 1); + assert.strictEqual(session.entries.get()[0].lastModifyingRequestId, 'req-1'); + assert.strictEqual(session.requestDisablement.get().length, 1); + assert.strictEqual(session.requestDisablement.get()[0].requestId, 'req-2'); + }); + + test('idempotent — calling twice does not duplicate', () => { + const session = createSession(store, new Map()); + + session.ensureRequestCheckpoint('req-1'); + session.ensureRequestCheckpoint('req-1'); + + // Should still work — only one sentinel + assert.strictEqual(session.hasEditsInRequest('req-1'), true); + }); + + test('splices stale checkpoints when called after restore', async () => { + const contentMap = new Map(); + contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before'); + const session = createSession(store, contentMap); + + session.addToolCallEdits('req-1', makeToolCall({ + toolCallId: 'tc-1', + filePath: '/workspace/file.ts', + beforeURI: URI.file('/workspace/file.ts').toString(), + afterURI: 'content://after-1', + })); + + session.ensureRequestCheckpoint('req-2'); + + // Undo to before req-2's sentinel + await session.undoInteraction(); + + // New request should splice away req-2 + session.ensureRequestCheckpoint('req-3'); + + assert.strictEqual(session.hasEditsInRequest('req-2'), false); + assert.strictEqual(session.hasEditsInRequest('req-3'), true); + }); + }); });