agentHost: support checkpointing and forking

This enables restore/undo checkpoint for agent host sessions. It also
does some initial work for forking although this is not yet fully
implemented.

The copilot SDK does not actually support these yet, so to do these we
shut down the session, rewrite copilot SDK's disk storage, and start it
back up again. It actually works. We'll need to make sure it works when
we upgrade the SDK, but I don't expect it to break terribly, as the
Copilot SDK folks must already be backwards-compatible to arbitrary old
SDK versions that exist on the user's device, and we'd essentially just
be an 'old SDK' with some dependency on internals. Of course that should
all be swapped out when they eventually add proper support for it.

I just flagged the specific scheme of agent host sessions thus far while
developing, but will clean up prior to merging.
This commit is contained in:
Connor Peet
2026-03-31 16:02:33 -07:00
parent 9ee53b40bd
commit 8fe8366dac
24 changed files with 2024 additions and 63 deletions

View File

@@ -101,6 +101,8 @@ export interface IAgentCreateSessionConfig {
readonly model?: string;
readonly session?: URI;
readonly workingDirectory?: URI;
/** Fork from an existing session at a specific turn index. */
readonly fork?: { readonly session: URI; readonly turnIndex: number };
}
/** Serializable attachment passed alongside a message to the agent host. */
@@ -364,6 +366,23 @@ export interface IAgent {
*/
authenticate(resource: string, token: string): Promise<boolean>;
/**
* Truncate a session's history. If `turnIndex` is provided (0-based), keeps
* turns up to and including that turn. If omitted, all turns are removed.
* Optional — not all providers support truncation.
*/
truncateSession?(session: URI, turnIndex?: number): Promise<void>;
/**
* Fork a session at a specific turn, creating a new session on disk
* with the source session's history up to and including the specified turn.
* Optional — not all providers support forking.
*
* @param turnIndex 0-based turn index to fork at.
* @returns The new session's raw ID.
*/
forkSession?(sourceSession: URI, newSessionId: string, turnIndex: number): Promise<void>;
/** Gracefully shut down all sessions. */
shutdown(): Promise<void>;

View File

@@ -9,7 +9,7 @@
// Generated from types/actions.ts — do not edit
// Run `npm run generate` to regenerate.
import { ActionType, type IStateAction, type IRootAgentsChangedAction, type IRootActiveSessionsChangedAction, type ISessionReadyAction, type ISessionCreationFailedAction, type ISessionTurnStartedAction, type ISessionDeltaAction, type ISessionResponsePartAction, type ISessionToolCallStartAction, type ISessionToolCallDeltaAction, type ISessionToolCallReadyAction, type ISessionToolCallConfirmedAction, type ISessionToolCallCompleteAction, type ISessionToolCallResultConfirmedAction, type ISessionTurnCompleteAction, type ISessionTurnCancelledAction, type ISessionErrorAction, type ISessionTitleChangedAction, type ISessionUsageAction, type ISessionReasoningAction, type ISessionModelChangedAction, type ISessionServerToolsChangedAction, type ISessionActiveClientChangedAction, type ISessionActiveClientToolsChangedAction, type ISessionPendingMessageSetAction, type ISessionPendingMessageRemovedAction, type ISessionQueuedMessagesReorderedAction, type ISessionCustomizationsChangedAction, type ISessionCustomizationToggledAction } from './actions.js';
import { ActionType, type IStateAction, type IRootAgentsChangedAction, type IRootActiveSessionsChangedAction, type ISessionReadyAction, type ISessionCreationFailedAction, type ISessionTurnStartedAction, type ISessionDeltaAction, type ISessionResponsePartAction, type ISessionToolCallStartAction, type ISessionToolCallDeltaAction, type ISessionToolCallReadyAction, type ISessionToolCallConfirmedAction, type ISessionToolCallCompleteAction, type ISessionToolCallResultConfirmedAction, type ISessionTurnCompleteAction, type ISessionTurnCancelledAction, type ISessionErrorAction, type ISessionTitleChangedAction, type ISessionUsageAction, type ISessionReasoningAction, type ISessionModelChangedAction, type ISessionServerToolsChangedAction, type ISessionActiveClientChangedAction, type ISessionActiveClientToolsChangedAction, type ISessionPendingMessageSetAction, type ISessionPendingMessageRemovedAction, type ISessionQueuedMessagesReorderedAction, type ISessionCustomizationsChangedAction, type ISessionCustomizationToggledAction, type ISessionTruncatedAction } from './actions.js';
// ─── Root vs Session Action Unions ───────────────────────────────────────────
@@ -48,6 +48,7 @@ export type ISessionAction =
| ISessionQueuedMessagesReorderedAction
| ISessionCustomizationsChangedAction
| ISessionCustomizationToggledAction
| ISessionTruncatedAction
;
/** Union of session actions that clients may dispatch. */
@@ -65,6 +66,7 @@ export type IClientSessionAction =
| ISessionPendingMessageRemovedAction
| ISessionQueuedMessagesReorderedAction
| ISessionCustomizationToggledAction
| ISessionTruncatedAction
;
/** Union of session actions that only the server may produce. */
@@ -119,4 +121,5 @@ export const IS_CLIENT_DISPATCHABLE: { readonly [K in IStateAction['type']]: boo
[ActionType.SessionQueuedMessagesReordered]: true,
[ActionType.SessionCustomizationsChanged]: false,
[ActionType.SessionCustomizationToggled]: true,
[ActionType.SessionTruncated]: true,
};

View File

@@ -45,6 +45,7 @@ export const enum ActionType {
SessionQueuedMessagesReordered = 'session/queuedMessagesReordered',
SessionCustomizationsChanged = 'session/customizationsChanged',
SessionCustomizationToggled = 'session/customizationToggled',
SessionTruncated = 'session/truncated',
}
// ─── Action Envelope ─────────────────────────────────────────────────────────
@@ -562,6 +563,31 @@ export interface ISessionCustomizationToggledAction {
enabled: boolean;
}
// ─── Truncation ──────────────────────────────────────────────────────────────
/**
* Truncates a session's history. If `turnId` is provided, all turns after that
* turn are removed and the specified turn is kept. If `turnId` is omitted, all
* turns are removed.
*
* If there is an active turn it is silently dropped and the session status
* returns to `idle`.
*
* Common use-case: truncate old data then dispatch a new
* `session/turnStarted` with an edited message.
*
* @category Session Actions
* @version 1
* @clientDispatchable
*/
export interface ISessionTruncatedAction {
type: ActionType.SessionTruncated;
/** Session URI */
session: URI;
/** Keep turns up to and including this turn. Omit to clear all turns. */
turnId?: string;
}
// ─── Pending Message Actions ─────────────────────────────────────────────────
/**
@@ -664,4 +690,5 @@ export type IStateAction =
| ISessionPendingMessageRemovedAction
| ISessionQueuedMessagesReorderedAction
| ISessionCustomizationsChangedAction
| ISessionCustomizationToggledAction;
| ISessionCustomizationToggledAction
| ISessionTruncatedAction;

View File

@@ -163,6 +163,20 @@ export interface ISubscribeResult {
* { "jsonrpc": "2.0", "id": 2, "error": { "code": -32003, "message": "Session already exists" } }
* ```
*/
/**
* Identifies a source session and turn to fork from.
*
* When provided in `createSession`, the server populates the new session with
* content from the source session up to and including the response of the
* specified turn.
*/
export interface ISessionForkSource {
/** URI of the existing session to fork from */
session: URI;
/** Turn ID in the source session; content up to and including this turn's response is copied */
turnId: string;
}
export interface ICreateSessionParams {
/** Session URI (client-chosen, e.g. `copilot:/<uuid>`) */
session: URI;
@@ -172,6 +186,11 @@ export interface ICreateSessionParams {
model?: string;
/** Working directory for the session */
workingDirectory?: URI;
/**
* Fork from an existing session. The new session is populated with content
* from the source session up to and including the specified turn's response.
*/
fork?: ISessionForkSource;
}
// ─── disposeSession ──────────────────────────────────────────────────────────

View File

@@ -485,6 +485,27 @@ export function sessionReducer(state: ISessionState, action: ISessionAction, log
return { ...state, customizations: updated };
}
// ── Truncation ────────────────────────────────────────────────────────
case ActionType.SessionTruncated: {
let turns: typeof state.turns;
if (action.turnId === undefined) {
turns = [];
} else {
const idx = state.turns.findIndex(t => t.id === action.turnId);
if (idx < 0) {
return state;
}
turns = state.turns.slice(0, idx + 1);
}
return {
...state,
turns,
activeTurn: undefined,
summary: { ...state.summary, status: SessionStatus.Idle, modifiedAt: Date.now() },
};
}
// ── Pending Messages ──────────────────────────────────────────────────
case ActionType.SessionPendingMessageSet: {

View File

@@ -52,6 +52,7 @@ export const ACTION_INTRODUCED_IN: { readonly [K in IStateAction['type']]: numbe
[ActionType.SessionQueuedMessagesReordered]: 1,
[ActionType.SessionCustomizationsChanged]: 1,
[ActionType.SessionCustomizationToggled]: 1,
[ActionType.SessionTruncated]: 1,
};
/**

View File

@@ -173,17 +173,41 @@ export class AgentService extends Disposable implements IAgentService {
this._sessionToProvider.set(session.toString(), provider.id);
this._logService.trace(`[AgentService] createSession returned: ${session.toString()}`);
// Create state in the state manager
const summary: ISessionSummary = {
resource: session.toString(),
provider: provider.id,
title: 'New Session',
status: SessionStatus.Idle,
createdAt: Date.now(),
modifiedAt: Date.now(),
workingDirectory: config?.workingDirectory?.toString(),
};
this._stateManager.createSession(summary);
// When forking, populate the new session's protocol state with
// the source session's turns so the client sees the forked history.
if (config?.fork) {
const sourceState = this._stateManager.getSessionState(config.fork.session.toString());
let sourceTurns: ITurn[] = [];
if (sourceState) {
const forkIdx = sourceState.turns.findIndex(t => t.id === config.fork!.turnId);
if (forkIdx >= 0) {
sourceTurns = sourceState.turns.slice(0, forkIdx + 1);
}
}
const summary: ISessionSummary = {
resource: session.toString(),
provider: provider.id,
title: sourceState?.summary.title ?? 'Forked Session',
status: SessionStatus.Idle,
createdAt: Date.now(),
modifiedAt: Date.now(),
workingDirectory: config.workingDirectory?.toString(),
};
this._stateManager.restoreSession(summary, sourceTurns);
} else {
// Create empty state for new sessions
const summary: ISessionSummary = {
resource: session.toString(),
provider: provider.id,
title: 'New Session',
status: SessionStatus.Idle,
createdAt: Date.now(),
modifiedAt: Date.now(),
workingDirectory: config?.workingDirectory?.toString(),
};
this._stateManager.createSession(summary);
}
this._stateManager.dispatchServerAction({ type: ActionType.SessionReady, session: session.toString() });
return session;

View File

@@ -252,6 +252,24 @@ export class AgentSideEffects extends Disposable {
this._syncPendingMessages(action.session);
break;
}
case ActionType.SessionTruncated: {
const agent = this._options.getAgent(action.session);
// Resolve the protocol turnId to a 0-based index using the
// state manager's turn list. The reducer has already applied
// the truncation, so we look at the pre-truncation state via
// the turnId position.
let turnIndex: number | undefined;
if (action.turnId !== undefined) {
const state = this._stateManager.getSessionState(action.session);
// After the reducer, the turns array is already truncated.
// The kept turns include the target, so its index = length - 1.
turnIndex = state ? state.turns.length - 1 : undefined;
}
agent?.truncateSession?.(URI.parse(action.session), turnIndex).catch(err => {
this._logService.error('[AgentSideEffects] truncateSession failed', err);
});
break;
}
}
}

View File

@@ -5,6 +5,7 @@
import { CopilotClient } from '@github/copilot-sdk';
import { rgPath } from '@vscode/ripgrep';
import { SequencerByKey } from '../../../../base/common/async.js';
import { Emitter } from '../../../../base/common/event.js';
import { Disposable, DisposableMap } from '../../../../base/common/lifecycle.js';
import { FileAccess } from '../../../../base/common/network.js';
@@ -18,6 +19,7 @@ import { AgentSession, IAgent, IAgentAttachment, IAgentCreateSessionConfig, IAge
import { type IPendingMessage, type PolicyState } from '../../common/state/sessionState.js';
import { CopilotAgentSession, SessionWrapperFactory } from './copilotAgentSession.js';
import { CopilotSessionWrapper } from './copilotSessionWrapper.js';
import { forkCopilotSessionOnDisk, getCopilotDataDir, truncateCopilotSessionOnDisk } from './copilotAgentForking.js';
/**
* Agent provider backed by the Copilot SDK {@link CopilotClient}.
@@ -32,6 +34,7 @@ export class CopilotAgent extends Disposable implements IAgent {
private _clientStarting: Promise<CopilotClient> | undefined;
private _githubToken: string | undefined;
private readonly _sessions = this._register(new DisposableMap<string, CopilotAgentSession>());
private readonly _sessionSequencer = new SequencerByKey<string>();
constructor(
@ILogService private readonly _logService: ILogService,
@@ -181,10 +184,38 @@ export class CopilotAgent extends Disposable implements IAgent {
this._logService.info(`[Copilot] Creating session... ${config?.model ? `model=${config.model}` : ''}`);
const client = await this._ensureClient();
// When forking, we manipulate the CLI's on-disk data and then resume
// instead of creating a fresh session via the SDK.
if (config?.fork) {
const sourceSessionId = AgentSession.id(config.fork.session);
const newSessionId = config.session ? AgentSession.id(config.session) : generateUuid();
// Serialize against the source session to prevent concurrent
// modifications while we read its on-disk data.
return this._sessionSequencer.queue(sourceSessionId, async () => {
this._logService.info(`[Copilot] Forking session ${sourceSessionId} at index ${config.fork!.turnIndex}${newSessionId}`);
// Ensure the source session is loaded so on-disk data is available
if (!this._sessions.has(sourceSessionId)) {
await this._resumeSession(sourceSessionId);
}
const copilotDataDir = getCopilotDataDir();
await forkCopilotSessionOnDisk(copilotDataDir, sourceSessionId, newSessionId, config.fork!.turnIndex);
// Resume the forked session so the SDK loads the forked history
const agentSession = await this._resumeSession(newSessionId);
const session = agentSession.sessionUri;
this._logService.info(`[Copilot] Forked session created: ${session.toString()}`);
return session;
});
}
const sessionId = config?.session ? AgentSession.id(config.session) : generateUuid();
const factory: SessionWrapperFactory = async callbacks => {
const raw = await client.createSession({
model: config?.model,
sessionId: config?.session ? AgentSession.id(config.session) : undefined,
sessionId,
streaming: true,
workingDirectory: config?.workingDirectory?.fsPath,
onPermissionRequest: callbacks.onPermissionRequest,
@@ -193,7 +224,7 @@ export class CopilotAgent extends Disposable implements IAgent {
return new CopilotSessionWrapper(raw);
};
const agentSession = this._createAgentSession(factory, config?.workingDirectory, config?.session ? AgentSession.id(config.session) : undefined);
const agentSession = this._createAgentSession(factory, config?.workingDirectory, sessionId);
await agentSession.initializeSession();
const session = agentSession.sessionUri;
@@ -203,8 +234,10 @@ export class CopilotAgent extends Disposable implements IAgent {
async sendMessage(session: URI, prompt: string, attachments?: IAgentAttachment[]): Promise<void> {
const sessionId = AgentSession.id(session);
const entry = this._sessions.get(sessionId) ?? await this._resumeSession(sessionId);
await entry.send(prompt, attachments);
await this._sessionSequencer.queue(sessionId, async () => {
const entry = this._sessions.get(sessionId) ?? await this._resumeSession(sessionId);
await entry.send(prompt, attachments);
});
}
setPendingMessages(session: URI, steeringMessage: IPendingMessage | undefined, _queuedMessages: readonly IPendingMessage[]): void {
@@ -236,15 +269,57 @@ export class CopilotAgent extends Disposable implements IAgent {
async disposeSession(session: URI): Promise<void> {
const sessionId = AgentSession.id(session);
this._sessions.deleteAndDispose(sessionId);
await this._sessionSequencer.queue(sessionId, async () => {
this._sessions.deleteAndDispose(sessionId);
});
}
async abortSession(session: URI): Promise<void> {
const sessionId = AgentSession.id(session);
const entry = this._sessions.get(sessionId);
if (entry) {
await entry.abort();
}
await this._sessionSequencer.queue(sessionId, async () => {
const entry = this._sessions.get(sessionId);
if (entry) {
await entry.abort();
}
});
}
async truncateSession(session: URI, turnIndex?: number): Promise<void> {
const sessionId = AgentSession.id(session);
await this._sessionSequencer.queue(sessionId, async () => {
this._logService.info(`[Copilot:${sessionId}] Truncating session${turnIndex !== undefined ? ` at index ${turnIndex}` : ' (all turns)'}`);
const keepUpToTurnIndex = turnIndex ?? -1;
// Destroy the SDK session first and wait for cleanup to complete,
// ensuring on-disk data (events.jsonl, locks) is released before
// we modify it. Then dispose the wrapper.
const entry = this._sessions.get(sessionId);
if (entry) {
await entry.destroySession();
}
this._sessions.deleteAndDispose(sessionId);
if (keepUpToTurnIndex >= 0) {
const copilotDataDir = getCopilotDataDir();
await truncateCopilotSessionOnDisk(copilotDataDir, sessionId, keepUpToTurnIndex);
}
// Resume the session from the modified on-disk data
await this._resumeSession(sessionId);
this._logService.info(`[Copilot:${sessionId}] Session truncated and resumed`);
});
}
async forkSession(sourceSession: URI, newSessionId: string, turnIndex: number): Promise<void> {
const sourceSessionId = AgentSession.id(sourceSession);
await this._sessionSequencer.queue(sourceSessionId, async () => {
this._logService.info(`[Copilot] Forking session ${sourceSessionId} at index ${turnIndex}${newSessionId}`);
const copilotDataDir = getCopilotDataDir();
await forkCopilotSessionOnDisk(copilotDataDir, sourceSessionId, newSessionId, turnIndex);
this._logService.info(`[Copilot] Forked session ${newSessionId} created on disk`);
});
}
async changeModel(session: URI, model: string): Promise<void> {
@@ -284,20 +359,19 @@ export class CopilotAgent extends Disposable implements IAgent {
* and returns it. The caller must call {@link CopilotAgentSession.initializeSession}
* to wire up the SDK session.
*/
private _createAgentSession(wrapperFactory: SessionWrapperFactory, workingDirectory: URI | undefined, sessionIdOverride?: string): CopilotAgentSession {
const rawId = sessionIdOverride ?? generateUuid();
const sessionUri = AgentSession.uri(this.id, rawId);
private _createAgentSession(wrapperFactory: SessionWrapperFactory, workingDirectory: URI | undefined, sessionId: string): CopilotAgentSession {
const sessionUri = AgentSession.uri(this.id, sessionId);
const agentSession = this._instantiationService.createInstance(
CopilotAgentSession,
sessionUri,
rawId,
sessionId,
workingDirectory,
this._onDidSessionProgress,
wrapperFactory,
);
this._sessions.set(rawId, agentSession);
this._sessions.set(sessionId, agentSession);
return agentSession;
}

View File

@@ -0,0 +1,565 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as fs from 'fs';
import * as os from 'os';
import type { Database } from '@vscode/sqlite3';
import { generateUuid } from '../../../../base/common/uuid.js';
import * as path from '../../../../base/common/path.js';
// ---- Types ------------------------------------------------------------------
/**
* A single event entry from a Copilot CLI `events.jsonl` file.
* The Copilot CLI stores session history as a newline-delimited JSON log
* where events form a linked list via `parentId`.
*/
export interface ICopilotEventLogEntry {
readonly type: string;
readonly data: Record<string, unknown>;
readonly id: string;
readonly timestamp: string;
readonly parentId: string | null;
}
// ---- Promise wrappers around callback-based @vscode/sqlite3 API -----------
function dbExec(db: Database, sql: string): Promise<void> {
return new Promise((resolve, reject) => {
db.exec(sql, err => err ? reject(err) : resolve());
});
}
function dbRun(db: Database, sql: string, params: unknown[]): Promise<void> {
return new Promise((resolve, reject) => {
db.run(sql, params, function (err: Error | null) {
if (err) {
return reject(err);
}
resolve();
});
});
}
function dbAll(db: Database, sql: string, params: unknown[]): Promise<Record<string, unknown>[]> {
return new Promise((resolve, reject) => {
db.all(sql, params, (err: Error | null, rows: Record<string, unknown>[]) => {
if (err) {
return reject(err);
}
resolve(rows);
});
});
}
function dbClose(db: Database): Promise<void> {
return new Promise((resolve, reject) => {
db.close(err => err ? reject(err) : resolve());
});
}
function dbOpen(dbPath: string): Promise<Database> {
return new Promise((resolve, reject) => {
import('@vscode/sqlite3').then(sqlite3 => {
const db = new sqlite3.default.Database(dbPath, (err: Error | null) => {
if (err) {
return reject(err);
}
resolve(db);
});
}, reject);
});
}
// ---- Pure functions (testable, no I/O) ------------------------------------
/**
* Parses a JSONL string into an array of event log entries.
*/
export function parseEventLog(content: string): ICopilotEventLogEntry[] {
const entries: ICopilotEventLogEntry[] = [];
for (const line of content.split('\n')) {
const trimmed = line.trim();
if (trimmed.length === 0) {
continue;
}
entries.push(JSON.parse(trimmed));
}
return entries;
}
/**
* Serializes an array of event log entries back into a JSONL string.
*/
export function serializeEventLog(entries: readonly ICopilotEventLogEntry[]): string {
return entries.map(e => JSON.stringify(e)).join('\n') + '\n';
}
/**
* Finds the index of the last event that belongs to the given turn (0-based).
*
* A "turn" corresponds to one `user.message` event and all subsequent events
* up to (and including) the `assistant.turn_end` that closes that interaction,
* or the `session.shutdown` that ends the session.
*
* @returns The inclusive index of the last event in the specified turn,
* or `-1` if the turn is not found.
*/
export function findTurnBoundaryInEventLog(entries: readonly ICopilotEventLogEntry[], turnIndex: number): number {
let userMessageCount = -1;
let lastEventForTurn = -1;
for (let i = 0; i < entries.length; i++) {
const entry = entries[i];
if (entry.type === 'user.message') {
userMessageCount++;
if (userMessageCount > turnIndex) {
// We've entered the next turn — stop
return lastEventForTurn;
}
}
if (userMessageCount === turnIndex) {
lastEventForTurn = i;
}
}
// If we scanned everything and the target turn was found, return its last event
return lastEventForTurn;
}
/**
* Builds a forked event log from the source session's events.
*
* - Keeps events up to and including the specified fork turn (0-based).
* - Rewrites `session.start` with the new session ID.
* - Generates fresh UUIDs for all events.
* - Re-chains `parentId` links via an old→new ID map.
* - Strips `session.shutdown` and `session.resume` lifecycle events.
*/
export function buildForkedEventLog(
entries: readonly ICopilotEventLogEntry[],
forkTurnIndex: number,
newSessionId: string,
): ICopilotEventLogEntry[] {
const boundary = findTurnBoundaryInEventLog(entries, forkTurnIndex);
if (boundary < 0) {
throw new Error(`Fork turn index ${forkTurnIndex} not found in event log`);
}
// Keep events up to boundary, filtering out lifecycle events
const kept = entries
.slice(0, boundary + 1)
.filter(e => e.type !== 'session.shutdown' && e.type !== 'session.resume');
// Build UUID remap and re-chain
const idMap = new Map<string, string>();
const result: ICopilotEventLogEntry[] = [];
for (const entry of kept) {
const newId = generateUuid();
idMap.set(entry.id, newId);
let data = entry.data;
if (entry.type === 'session.start') {
data = { ...data, sessionId: newSessionId };
}
const newParentId = entry.parentId !== null
? (idMap.get(entry.parentId) ?? null)
: null;
result.push({
type: entry.type,
data,
id: newId,
timestamp: entry.timestamp,
parentId: newParentId,
});
}
return result;
}
/**
* Builds a truncated event log from the source session's events.
*
* - Keeps events up to and including the specified turn (0-based).
* - Prepends a new `session.start` event using the original start data.
* - Re-chains `parentId` links for remaining events.
*/
export function buildTruncatedEventLog(
entries: readonly ICopilotEventLogEntry[],
keepUpToTurnIndex: number,
): ICopilotEventLogEntry[] {
const boundary = findTurnBoundaryInEventLog(entries, keepUpToTurnIndex);
if (boundary < 0) {
throw new Error(`Turn index ${keepUpToTurnIndex} not found in event log`);
}
// Find the original session.start for its metadata
const originalStart = entries.find(e => e.type === 'session.start');
if (!originalStart) {
throw new Error('No session.start event found in event log');
}
// Keep events from after session start up to boundary, stripping lifecycle events
const kept = entries
.slice(0, boundary + 1)
.filter(e => e.type !== 'session.start' && e.type !== 'session.shutdown' && e.type !== 'session.resume');
// Build new start event
const newStartId = generateUuid();
const newStart: ICopilotEventLogEntry = {
type: 'session.start',
data: { ...originalStart.data, startTime: new Date().toISOString() },
id: newStartId,
timestamp: new Date().toISOString(),
parentId: null,
};
// Re-chain: first remaining event points to the new start
const idMap = new Map<string, string>();
idMap.set(originalStart.id, newStartId);
const result: ICopilotEventLogEntry[] = [newStart];
let lastId = newStartId;
for (const entry of kept) {
const newId = generateUuid();
idMap.set(entry.id, newId);
const newParentId = entry.parentId !== null
? (idMap.get(entry.parentId) ?? lastId)
: lastId;
result.push({
type: entry.type,
data: entry.data,
id: newId,
timestamp: entry.timestamp,
parentId: newParentId,
});
lastId = newId;
}
return result;
}
/**
* Generates a `workspace.yaml` file content for a Copilot CLI session.
*/
export function buildWorkspaceYaml(sessionId: string, cwd: string, summary: string): string {
const now = new Date().toISOString();
return [
`id: ${sessionId}`,
`cwd: ${cwd}`,
`summary_count: 0`,
`created_at: ${now}`,
`updated_at: ${now}`,
`summary: ${summary}`,
'',
].join('\n');
}
// ---- SQLite operations (Copilot CLI session-store.db) ---------------------
/**
* Forks a session record in the Copilot CLI's `session-store.db`.
*
* Copies the source session's metadata, turns (up to `forkTurnIndex`),
* session files, search index entries, and checkpoints into a new session.
*/
export async function forkSessionInDb(
db: Database,
sourceSessionId: string,
newSessionId: string,
forkTurnIndex: number,
): Promise<void> {
await dbExec(db, 'PRAGMA foreign_keys = ON');
await dbExec(db, 'BEGIN TRANSACTION');
try {
const now = new Date().toISOString();
// Copy session row
await dbRun(db,
`INSERT INTO sessions (id, cwd, repository, branch, summary, created_at, updated_at, host_type)
SELECT ?, cwd, repository, branch, summary, ?, ?, host_type
FROM sessions WHERE id = ?`,
[newSessionId, now, now, sourceSessionId],
);
// Copy turns up to fork point (turn_index is 0-based)
await dbRun(db,
`INSERT INTO turns (session_id, turn_index, user_message, assistant_response, timestamp)
SELECT ?, turn_index, user_message, assistant_response, timestamp
FROM turns
WHERE session_id = ? AND turn_index <= ?`,
[newSessionId, sourceSessionId, forkTurnIndex],
);
// Copy session files that were first seen at or before the fork point
await dbRun(db,
`INSERT INTO session_files (session_id, file_path, tool_name, turn_index, first_seen_at)
SELECT ?, file_path, tool_name, turn_index, first_seen_at
FROM session_files
WHERE session_id = ? AND turn_index <= ?`,
[newSessionId, sourceSessionId, forkTurnIndex],
);
// Copy search index entries for kept turns
await dbRun(db,
`INSERT INTO search_index (content, session_id, source_type, source_id)
SELECT content, ?, source_type,
REPLACE(source_id, ?, ?)
FROM search_index
WHERE session_id = ?
AND source_type = 'turn'`,
[newSessionId, sourceSessionId, newSessionId, sourceSessionId],
);
// Copy checkpoints at or before the fork point
await dbRun(db,
`INSERT INTO checkpoints (session_id, checkpoint_number, title, overview, history, work_done, technical_details, important_files, next_steps, created_at)
SELECT ?, checkpoint_number, title, overview, history, work_done, technical_details, important_files, next_steps, created_at
FROM checkpoints
WHERE session_id = ?`,
[newSessionId, sourceSessionId],
);
await dbExec(db, 'COMMIT');
} catch (err) {
await dbExec(db, 'ROLLBACK');
throw err;
}
}
/**
* Truncates a session in the Copilot CLI's `session-store.db`.
*
* Removes all turns after `keepUpToTurnIndex` and updates session metadata.
*/
export async function truncateSessionInDb(
db: Database,
sessionId: string,
keepUpToTurnIndex: number,
): Promise<void> {
await dbExec(db, 'PRAGMA foreign_keys = ON');
await dbExec(db, 'BEGIN TRANSACTION');
try {
const now = new Date().toISOString();
// Delete turns after the truncation point
await dbRun(db,
`DELETE FROM turns WHERE session_id = ? AND turn_index > ?`,
[sessionId, keepUpToTurnIndex],
);
// Update session timestamp
await dbRun(db,
`UPDATE sessions SET updated_at = ? WHERE id = ?`,
[now, sessionId],
);
// Remove search index entries for removed turns
// source_id format is "<session_id>:turn:<turn_index>"
await dbAll(db,
`SELECT source_id FROM search_index
WHERE session_id = ? AND source_type = 'turn'`,
[sessionId],
).then(async rows => {
const prefix = `${sessionId}:turn:`;
for (const row of rows) {
const sourceId = row.source_id as string;
if (sourceId.startsWith(prefix)) {
const turnIdx = parseInt(sourceId.substring(prefix.length), 10);
if (!isNaN(turnIdx) && turnIdx > keepUpToTurnIndex) {
await dbRun(db,
`DELETE FROM search_index WHERE source_id = ? AND session_id = ?`,
[sourceId, sessionId],
);
}
}
}
});
await dbExec(db, 'COMMIT');
} catch (err) {
await dbExec(db, 'ROLLBACK');
throw err;
}
}
// ---- File system operations -----------------------------------------------
/**
* Resolves the Copilot CLI data directory.
* The Copilot CLI stores its data in `~/.copilot/` by default, or in the
* directory specified by `COPILOT_CONFIG_DIR`.
*/
export function getCopilotDataDir(): string {
return process.env['COPILOT_CONFIG_DIR'] ?? path.join(os.homedir(), '.copilot');
}
/**
* Forks a Copilot CLI session on disk.
*
* 1. Reads the source session's `events.jsonl`
* 2. Builds a forked event log
* 3. Creates the new session folder with all required files/directories
* 4. Updates the `session-store.db`
*
* @param copilotDataDir Path to the `.copilot` directory
* @param sourceSessionId UUID of the source session to fork from
* @param newSessionId UUID for the new forked session
* @param forkTurnIndex 0-based turn index to fork at (inclusive)
*/
export async function forkCopilotSessionOnDisk(
copilotDataDir: string,
sourceSessionId: string,
newSessionId: string,
forkTurnIndex: number,
): Promise<void> {
const sessionStateDir = path.join(copilotDataDir, 'session-state');
// Read source events
const sourceEventsPath = path.join(sessionStateDir, sourceSessionId, 'events.jsonl');
const sourceContent = await fs.promises.readFile(sourceEventsPath, 'utf-8');
const sourceEntries = parseEventLog(sourceContent);
// Build forked event log
const forkedEntries = buildForkedEventLog(sourceEntries, forkTurnIndex, newSessionId);
// Read source workspace.yaml for cwd/summary
let cwd = '';
let summary = '';
try {
const workspaceYamlPath = path.join(sessionStateDir, sourceSessionId, 'workspace.yaml');
const yamlContent = await fs.promises.readFile(workspaceYamlPath, 'utf-8');
const cwdMatch = yamlContent.match(/^cwd:\s*(.+)$/m);
const summaryMatch = yamlContent.match(/^summary:\s*(.+)$/m);
if (cwdMatch) {
cwd = cwdMatch[1].trim();
}
if (summaryMatch) {
summary = summaryMatch[1].trim();
}
} catch {
// Fall back to session.start data
const startEvent = sourceEntries.find(e => e.type === 'session.start');
if (startEvent) {
const ctx = startEvent.data.context as Record<string, string> | undefined;
cwd = ctx?.cwd ?? '';
}
}
// Create new session folder structure
const newSessionDir = path.join(sessionStateDir, newSessionId);
await fs.promises.mkdir(path.join(newSessionDir, 'checkpoints'), { recursive: true });
await fs.promises.mkdir(path.join(newSessionDir, 'files'), { recursive: true });
await fs.promises.mkdir(path.join(newSessionDir, 'research'), { recursive: true });
// Write events.jsonl
await fs.promises.writeFile(
path.join(newSessionDir, 'events.jsonl'),
serializeEventLog(forkedEntries),
'utf-8',
);
// Write workspace.yaml
await fs.promises.writeFile(
path.join(newSessionDir, 'workspace.yaml'),
buildWorkspaceYaml(newSessionId, cwd, summary),
'utf-8',
);
// Write empty vscode.metadata.json
await fs.promises.writeFile(
path.join(newSessionDir, 'vscode.metadata.json'),
'{}',
'utf-8',
);
// Write empty checkpoints index
await fs.promises.writeFile(
path.join(newSessionDir, 'checkpoints', 'index.md'),
'',
'utf-8',
);
// Update session-store.db
const dbPath = path.join(copilotDataDir, 'session-store.db');
const db = await dbOpen(dbPath);
try {
await forkSessionInDb(db, sourceSessionId, newSessionId, forkTurnIndex);
} finally {
await dbClose(db);
}
}
/**
* Truncates a Copilot CLI session on disk.
*
* 1. Reads the session's `events.jsonl`
* 2. Builds a truncated event log
* 3. Overwrites `events.jsonl` and updates `workspace.yaml`
* 4. Updates the `session-store.db`
*
* @param copilotDataDir Path to the `.copilot` directory
* @param sessionId UUID of the session to truncate
* @param keepUpToTurnIndex 0-based turn index to keep up to (inclusive)
*/
export async function truncateCopilotSessionOnDisk(
copilotDataDir: string,
sessionId: string,
keepUpToTurnIndex: number,
): Promise<void> {
const sessionStateDir = path.join(copilotDataDir, 'session-state');
const sessionDir = path.join(sessionStateDir, sessionId);
// Read and truncate events
const eventsPath = path.join(sessionDir, 'events.jsonl');
const content = await fs.promises.readFile(eventsPath, 'utf-8');
const entries = parseEventLog(content);
const truncatedEntries = buildTruncatedEventLog(entries, keepUpToTurnIndex);
// Overwrite events.jsonl
await fs.promises.writeFile(eventsPath, serializeEventLog(truncatedEntries), 'utf-8');
// Update workspace.yaml timestamp
try {
const yamlPath = path.join(sessionDir, 'workspace.yaml');
let yaml = await fs.promises.readFile(yamlPath, 'utf-8');
yaml = yaml.replace(/^updated_at:\s*.+$/m, `updated_at: ${new Date().toISOString()}`);
await fs.promises.writeFile(yamlPath, yaml, 'utf-8');
} catch {
// workspace.yaml may not exist (old format)
}
// Update session-store.db
const dbPath = path.join(copilotDataDir, 'session-store.db');
const db = await dbOpen(dbPath);
try {
await truncateSessionInDb(db, sessionId, keepUpToTurnIndex);
} finally {
await dbClose(db);
}
}
/**
* Maps a protocol turn ID to a 0-based turn index by finding the turn's
* position within the session's event log.
*
* The protocol state assigns arbitrary string IDs to turns, but the Copilot
* CLI's `events.jsonl` uses sequential `user.message` events. To bridge the
* two, we match turns by their position in the sequence.
*
* @returns The 0-based turn index, or `-1` if the turn ID is not found in the
* `turnIds` array.
*/
export function turnIdToIndex(turnIds: readonly string[], turnId: string): number {
return turnIds.indexOf(turnId);
}

View File

@@ -228,6 +228,16 @@ export class CopilotAgentSession extends Disposable {
await this._wrapper.session.abort();
}
/**
* Explicitly destroys the underlying SDK session and waits for cleanup
* to complete. Call this before {@link dispose} when you need to ensure
* the session's on-disk data is no longer locked (e.g. before
* truncation or fork operations that modify the session files).
*/
async destroySession(): Promise<void> {
await this._wrapper.session.destroy();
}
async setModel(model: string): Promise<void> {
this._logService.info(`[Copilot:${this.sessionId}] Changing model to: ${model}`);
await this._wrapper.session.setModel(model);

View File

@@ -310,12 +310,24 @@ export class ProtocolServerHandler extends Disposable {
},
createSession: async (_client, params) => {
let createdSession: URI;
// Resolve fork turnId to a 0-based index using the source session's
// turn list in the state manager.
let fork: { session: URI; turnIndex: number } | undefined;
if (params.fork) {
const sourceState = this._stateManager.getSessionState(params.fork.session);
const turnIndex = sourceState?.turns.findIndex(t => t.id === params.fork!.turnId) ?? -1;
if (turnIndex < 0) {
throw new ProtocolError(AHP_PROVIDER_NOT_FOUND, `Fork turn ID ${params.fork.turnId} not found in session ${params.fork.session}`);
}
fork = { session: URI.parse(params.fork.session), turnIndex };
}
try {
createdSession = await this._agentService.createSession({
provider: params.provider,
model: params.model,
workingDirectory: params.workingDirectory ? URI.parse(params.workingDirectory) : undefined,
session: URI.parse(params.session),
fork,
});
} catch (err) {
if (err instanceof ProtocolError) {

View File

@@ -0,0 +1,650 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import assert from 'assert';
import { ensureNoDisposablesAreLeakedInTestSuite } from '../../../../base/test/common/utils.js';
import {
parseEventLog,
serializeEventLog,
findTurnBoundaryInEventLog,
buildForkedEventLog,
buildTruncatedEventLog,
buildWorkspaceYaml,
forkSessionInDb,
truncateSessionInDb,
type ICopilotEventLogEntry,
} from '../../node/copilot/copilotAgentForking.js';
// ---- Test helpers -----------------------------------------------------------
function makeEntry(type: string, overrides?: Partial<ICopilotEventLogEntry>): ICopilotEventLogEntry {
return {
type,
data: {},
id: `id-${Math.random().toString(36).slice(2, 8)}`,
timestamp: new Date().toISOString(),
parentId: null,
...overrides,
};
}
/**
* Builds a minimal event log representing a multi-turn session.
* Each turn = user.message → assistant.turn_start → assistant.message → assistant.turn_end.
*/
function buildTestEventLog(turnCount: number): ICopilotEventLogEntry[] {
const entries: ICopilotEventLogEntry[] = [];
let lastId: string | null = null;
const sessionStart = makeEntry('session.start', {
id: 'session-start-id',
data: { sessionId: 'source-session', context: { cwd: '/test' } },
parentId: null,
});
entries.push(sessionStart);
lastId = sessionStart.id;
for (let turn = 0; turn < turnCount; turn++) {
const userMsg = makeEntry('user.message', {
id: `user-msg-${turn}`,
data: { content: `Turn ${turn} message` },
parentId: lastId,
});
entries.push(userMsg);
lastId = userMsg.id;
const turnStart = makeEntry('assistant.turn_start', {
id: `turn-start-${turn}`,
data: { turnId: String(turn) },
parentId: lastId,
});
entries.push(turnStart);
lastId = turnStart.id;
const assistantMsg = makeEntry('assistant.message', {
id: `assistant-msg-${turn}`,
data: { content: `Response ${turn}` },
parentId: lastId,
});
entries.push(assistantMsg);
lastId = assistantMsg.id;
const turnEnd = makeEntry('assistant.turn_end', {
id: `turn-end-${turn}`,
parentId: lastId,
});
entries.push(turnEnd);
lastId = turnEnd.id;
}
return entries;
}
suite('CopilotAgentForking', () => {
ensureNoDisposablesAreLeakedInTestSuite();
// ---- parseEventLog / serializeEventLog ------------------------------
suite('parseEventLog', () => {
test('parses a single-line JSONL', () => {
const entry = makeEntry('session.start');
const jsonl = JSON.stringify(entry);
const result = parseEventLog(jsonl);
assert.strictEqual(result.length, 1);
assert.strictEqual(result[0].type, 'session.start');
});
test('parses multi-line JSONL', () => {
const entries = [makeEntry('session.start'), makeEntry('user.message')];
const jsonl = entries.map(e => JSON.stringify(e)).join('\n');
const result = parseEventLog(jsonl);
assert.strictEqual(result.length, 2);
});
test('ignores empty lines', () => {
const entry = makeEntry('session.start');
const jsonl = '\n' + JSON.stringify(entry) + '\n\n';
const result = parseEventLog(jsonl);
assert.strictEqual(result.length, 1);
});
test('empty input returns empty array', () => {
assert.deepStrictEqual(parseEventLog(''), []);
assert.deepStrictEqual(parseEventLog('\n\n'), []);
});
});
suite('serializeEventLog', () => {
test('round-trips correctly', () => {
const entries = buildTestEventLog(2);
const serialized = serializeEventLog(entries);
const parsed = parseEventLog(serialized);
assert.strictEqual(parsed.length, entries.length);
for (let i = 0; i < entries.length; i++) {
assert.strictEqual(parsed[i].id, entries[i].id);
assert.strictEqual(parsed[i].type, entries[i].type);
}
});
test('ends with a newline', () => {
const entries = [makeEntry('session.start')];
const serialized = serializeEventLog(entries);
assert.ok(serialized.endsWith('\n'));
});
});
// ---- findTurnBoundaryInEventLog -------------------------------------
suite('findTurnBoundaryInEventLog', () => {
test('finds first turn boundary', () => {
const entries = buildTestEventLog(3);
const boundary = findTurnBoundaryInEventLog(entries, 0);
// Turn 0: user.message(1) + turn_start(2) + assistant.message(3) + turn_end(4)
// Index 4 = turn_end of turn 0
assert.strictEqual(boundary, 4);
assert.strictEqual(entries[boundary].type, 'assistant.turn_end');
assert.strictEqual(entries[boundary].id, 'turn-end-0');
});
test('finds middle turn boundary', () => {
const entries = buildTestEventLog(3);
const boundary = findTurnBoundaryInEventLog(entries, 1);
// Turn 1 ends at index 8
assert.strictEqual(boundary, 8);
assert.strictEqual(entries[boundary].type, 'assistant.turn_end');
assert.strictEqual(entries[boundary].id, 'turn-end-1');
});
test('finds last turn boundary', () => {
const entries = buildTestEventLog(3);
const boundary = findTurnBoundaryInEventLog(entries, 2);
assert.strictEqual(boundary, entries.length - 1);
assert.strictEqual(entries[boundary].type, 'assistant.turn_end');
assert.strictEqual(entries[boundary].id, 'turn-end-2');
});
test('returns -1 for non-existent turn', () => {
const entries = buildTestEventLog(2);
assert.strictEqual(findTurnBoundaryInEventLog(entries, 5), -1);
});
test('returns -1 for empty log', () => {
assert.strictEqual(findTurnBoundaryInEventLog([], 0), -1);
});
});
// ---- buildForkedEventLog --------------------------------------------
suite('buildForkedEventLog', () => {
test('forks at turn 0', () => {
const entries = buildTestEventLog(3);
const forked = buildForkedEventLog(entries, 0, 'new-session-id');
// Should have session.start + turn 0 events (user.message, turn_start, assistant.message, turn_end)
assert.strictEqual(forked.length, 5);
assert.strictEqual(forked[0].type, 'session.start');
assert.strictEqual((forked[0].data as Record<string, unknown>).sessionId, 'new-session-id');
});
test('forks at turn 1', () => {
const entries = buildTestEventLog(3);
const forked = buildForkedEventLog(entries, 1, 'new-session-id');
// session.start + 2 turns × 4 events = 9 events
assert.strictEqual(forked.length, 9);
});
test('generates unique UUIDs', () => {
const entries = buildTestEventLog(2);
const forked = buildForkedEventLog(entries, 0, 'new-session-id');
const ids = new Set(forked.map(e => e.id));
assert.strictEqual(ids.size, forked.length, 'All IDs should be unique');
// None should match the original
for (const entry of forked) {
assert.ok(!entries.some(e => e.id === entry.id), 'Should not reuse original IDs');
}
});
test('re-chains parentId links', () => {
const entries = buildTestEventLog(2);
const forked = buildForkedEventLog(entries, 0, 'new-session-id');
// First event has no parent
assert.strictEqual(forked[0].parentId, null);
// Each subsequent event's parentId should be a valid ID in the forked log
const idSet = new Set(forked.map(e => e.id));
for (let i = 1; i < forked.length; i++) {
assert.ok(
forked[i].parentId !== null && idSet.has(forked[i].parentId!),
`Event ${i} (${forked[i].type}) should have a valid parentId`,
);
}
});
test('strips session.shutdown and session.resume events', () => {
const entries = buildTestEventLog(2);
// Insert lifecycle events
entries.splice(5, 0, makeEntry('session.shutdown', { parentId: entries[4].id }));
entries.splice(6, 0, makeEntry('session.resume', { parentId: entries[5].id }));
const forked = buildForkedEventLog(entries, 1, 'new-session-id');
assert.ok(!forked.some(e => e.type === 'session.shutdown'));
assert.ok(!forked.some(e => e.type === 'session.resume'));
});
test('throws for invalid turn index', () => {
const entries = buildTestEventLog(1);
assert.throws(() => buildForkedEventLog(entries, 5, 'new-session-id'));
});
});
// ---- buildTruncatedEventLog -----------------------------------------
suite('buildTruncatedEventLog', () => {
test('truncates to turn 0', () => {
const entries = buildTestEventLog(3);
const truncated = buildTruncatedEventLog(entries, 0);
// New session.start + turn 0 events (user.message, turn_start, assistant.message, turn_end)
assert.strictEqual(truncated.length, 5);
assert.strictEqual(truncated[0].type, 'session.start');
});
test('truncates to turn 1', () => {
const entries = buildTestEventLog(3);
const truncated = buildTruncatedEventLog(entries, 1);
// New session.start + 2 turns × 4 events = 9 events
assert.strictEqual(truncated.length, 9);
});
test('prepends fresh session.start', () => {
const entries = buildTestEventLog(2);
const truncated = buildTruncatedEventLog(entries, 0);
assert.strictEqual(truncated[0].type, 'session.start');
assert.strictEqual(truncated[0].parentId, null);
// Should not reuse original session.start ID
assert.notStrictEqual(truncated[0].id, entries[0].id);
});
test('re-chains parentId links', () => {
const entries = buildTestEventLog(2);
const truncated = buildTruncatedEventLog(entries, 0);
const idSet = new Set(truncated.map(e => e.id));
for (let i = 1; i < truncated.length; i++) {
assert.ok(
truncated[i].parentId !== null && idSet.has(truncated[i].parentId!),
`Event ${i} (${truncated[i].type}) should have a valid parentId`,
);
}
});
test('strips lifecycle events', () => {
const entries = buildTestEventLog(3);
// Add lifecycle events between turns
entries.splice(5, 0, makeEntry('session.shutdown'));
entries.splice(6, 0, makeEntry('session.resume'));
const truncated = buildTruncatedEventLog(entries, 2);
const lifecycleEvents = truncated.filter(
e => e.type === 'session.shutdown' || e.type === 'session.resume',
);
assert.strictEqual(lifecycleEvents.length, 0);
});
test('throws for invalid turn index', () => {
const entries = buildTestEventLog(1);
assert.throws(() => buildTruncatedEventLog(entries, 5));
});
test('throws when no session.start exists', () => {
const entries = [makeEntry('user.message')];
assert.throws(() => buildTruncatedEventLog(entries, 0));
});
});
// ---- buildWorkspaceYaml ---------------------------------------------
suite('buildWorkspaceYaml', () => {
test('contains required fields', () => {
const yaml = buildWorkspaceYaml('test-id', '/home/user/project', 'Test summary');
assert.ok(yaml.includes('id: test-id'));
assert.ok(yaml.includes('cwd: /home/user/project'));
assert.ok(yaml.includes('summary: Test summary'));
assert.ok(yaml.includes('summary_count: 0'));
assert.ok(yaml.includes('created_at:'));
assert.ok(yaml.includes('updated_at:'));
});
});
// ---- SQLite operations (in-memory) ----------------------------------
suite('forkSessionInDb', () => {
async function openTestDb(): Promise<import('@vscode/sqlite3').Database> {
const sqlite3 = await import('@vscode/sqlite3');
return new Promise((resolve, reject) => {
const db = new sqlite3.default.Database(':memory:', (err: Error | null) => {
if (err) {
return reject(err);
}
resolve(db);
});
});
}
function exec(db: import('@vscode/sqlite3').Database, sql: string): Promise<void> {
return new Promise((resolve, reject) => {
db.exec(sql, err => err ? reject(err) : resolve());
});
}
function all(db: import('@vscode/sqlite3').Database, sql: string, params: unknown[] = []): Promise<Record<string, unknown>[]> {
return new Promise((resolve, reject) => {
db.all(sql, params, (err: Error | null, rows: Record<string, unknown>[]) => {
if (err) {
return reject(err);
}
resolve(rows);
});
});
}
function close(db: import('@vscode/sqlite3').Database): Promise<void> {
return new Promise((resolve, reject) => {
db.close(err => err ? reject(err) : resolve());
});
}
async function setupSchema(db: import('@vscode/sqlite3').Database): Promise<void> {
await exec(db, `
CREATE TABLE sessions (
id TEXT PRIMARY KEY,
cwd TEXT,
repository TEXT,
branch TEXT,
summary TEXT,
created_at TEXT,
updated_at TEXT,
host_type TEXT
);
CREATE TABLE turns (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL,
turn_index INTEGER NOT NULL,
user_message TEXT,
assistant_response TEXT,
timestamp TEXT,
UNIQUE(session_id, turn_index)
);
CREATE TABLE session_files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL,
file_path TEXT,
tool_name TEXT,
turn_index INTEGER,
first_seen_at TEXT
);
CREATE VIRTUAL TABLE search_index USING fts5(
content,
session_id,
source_type,
source_id
);
CREATE TABLE checkpoints (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL,
checkpoint_number INTEGER,
title TEXT,
overview TEXT,
history TEXT,
work_done TEXT,
technical_details TEXT,
important_files TEXT,
next_steps TEXT,
created_at TEXT
);
`);
}
async function seedTestData(db: import('@vscode/sqlite3').Database, sessionId: string, turnCount: number): Promise<void> {
await exec(db, `
INSERT INTO sessions (id, cwd, repository, branch, summary, created_at, updated_at, host_type)
VALUES ('${sessionId}', '/test', 'test-repo', 'main', 'Test session', '2026-01-01', '2026-01-01', 'github');
`);
for (let i = 0; i < turnCount; i++) {
await exec(db, `
INSERT INTO turns (session_id, turn_index, user_message, assistant_response, timestamp)
VALUES ('${sessionId}', ${i}, 'msg ${i}', 'resp ${i}', '2026-01-01');
`);
await exec(db, `
INSERT INTO session_files (session_id, file_path, tool_name, turn_index, first_seen_at)
VALUES ('${sessionId}', 'file${i}.ts', 'edit', ${i}, '2026-01-01');
`);
}
}
test('copies session metadata', async () => {
const db = await openTestDb();
try {
await setupSchema(db);
await seedTestData(db, 'source', 3);
await forkSessionInDb(db, 'source', 'forked', 1);
const sessions = await all(db, 'SELECT * FROM sessions WHERE id = ?', ['forked']);
assert.strictEqual(sessions.length, 1);
assert.strictEqual(sessions[0].cwd, '/test');
assert.strictEqual(sessions[0].repository, 'test-repo');
} finally {
await close(db);
}
});
test('copies turns up to fork point', async () => {
const db = await openTestDb();
try {
await setupSchema(db);
await seedTestData(db, 'source', 3);
await forkSessionInDb(db, 'source', 'forked', 1);
const turns = await all(db, 'SELECT * FROM turns WHERE session_id = ? ORDER BY turn_index', ['forked']);
assert.strictEqual(turns.length, 2); // turns 0 and 1
assert.strictEqual(turns[0].turn_index, 0);
assert.strictEqual(turns[1].turn_index, 1);
} finally {
await close(db);
}
});
test('copies session files up to fork point', async () => {
const db = await openTestDb();
try {
await setupSchema(db);
await seedTestData(db, 'source', 3);
await forkSessionInDb(db, 'source', 'forked', 1);
const files = await all(db, 'SELECT * FROM session_files WHERE session_id = ?', ['forked']);
assert.strictEqual(files.length, 2); // files from turns 0 and 1
} finally {
await close(db);
}
});
test('does not affect source session', async () => {
const db = await openTestDb();
try {
await setupSchema(db);
await seedTestData(db, 'source', 3);
await forkSessionInDb(db, 'source', 'forked', 1);
const sourceTurns = await all(db, 'SELECT * FROM turns WHERE session_id = ?', ['source']);
assert.strictEqual(sourceTurns.length, 3);
} finally {
await close(db);
}
});
});
suite('truncateSessionInDb', () => {
async function openTestDb(): Promise<import('@vscode/sqlite3').Database> {
const sqlite3 = await import('@vscode/sqlite3');
return new Promise((resolve, reject) => {
const db = new sqlite3.default.Database(':memory:', (err: Error | null) => {
if (err) {
return reject(err);
}
resolve(db);
});
});
}
function exec(db: import('@vscode/sqlite3').Database, sql: string): Promise<void> {
return new Promise((resolve, reject) => {
db.exec(sql, err => err ? reject(err) : resolve());
});
}
function all(db: import('@vscode/sqlite3').Database, sql: string, params: unknown[] = []): Promise<Record<string, unknown>[]> {
return new Promise((resolve, reject) => {
db.all(sql, params, (err: Error | null, rows: Record<string, unknown>[]) => {
if (err) {
return reject(err);
}
resolve(rows);
});
});
}
function close(db: import('@vscode/sqlite3').Database): Promise<void> {
return new Promise((resolve, reject) => {
db.close(err => err ? reject(err) : resolve());
});
}
async function setupSchema(db: import('@vscode/sqlite3').Database): Promise<void> {
await exec(db, `
CREATE TABLE sessions (
id TEXT PRIMARY KEY,
cwd TEXT,
repository TEXT,
branch TEXT,
summary TEXT,
created_at TEXT,
updated_at TEXT,
host_type TEXT
);
CREATE TABLE turns (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL,
turn_index INTEGER NOT NULL,
user_message TEXT,
assistant_response TEXT,
timestamp TEXT,
UNIQUE(session_id, turn_index)
);
CREATE VIRTUAL TABLE search_index USING fts5(
content,
session_id,
source_type,
source_id
);
`);
}
test('removes turns after truncation point', async () => {
const db = await openTestDb();
try {
await setupSchema(db);
await exec(db, `
INSERT INTO sessions (id, cwd, summary, created_at, updated_at)
VALUES ('sess', '/test', 'Test', '2026-01-01', '2026-01-01');
`);
for (let i = 0; i < 5; i++) {
await exec(db, `
INSERT INTO turns (session_id, turn_index, user_message, timestamp)
VALUES ('sess', ${i}, 'msg ${i}', '2026-01-01');
`);
}
await truncateSessionInDb(db, 'sess', 2);
const turns = await all(db, 'SELECT * FROM turns WHERE session_id = ? ORDER BY turn_index', ['sess']);
assert.strictEqual(turns.length, 3); // turns 0, 1, 2
assert.strictEqual(turns[0].turn_index, 0);
assert.strictEqual(turns[2].turn_index, 2);
} finally {
await close(db);
}
});
test('updates session timestamp', async () => {
const db = await openTestDb();
try {
await setupSchema(db);
await exec(db, `
INSERT INTO sessions (id, cwd, summary, created_at, updated_at)
VALUES ('sess', '/test', 'Test', '2026-01-01', '2026-01-01');
`);
await exec(db, `
INSERT INTO turns (session_id, turn_index, user_message, timestamp)
VALUES ('sess', 0, 'msg 0', '2026-01-01');
`);
await truncateSessionInDb(db, 'sess', 0);
const sessions = await all(db, 'SELECT updated_at FROM sessions WHERE id = ?', ['sess']);
assert.notStrictEqual(sessions[0].updated_at, '2026-01-01');
} finally {
await close(db);
}
});
test('removes search index entries for truncated turns', async () => {
const db = await openTestDb();
try {
await setupSchema(db);
await exec(db, `
INSERT INTO sessions (id, cwd, summary, created_at, updated_at)
VALUES ('sess', '/test', 'Test', '2026-01-01', '2026-01-01');
`);
for (let i = 0; i < 3; i++) {
await exec(db, `
INSERT INTO turns (session_id, turn_index, user_message, timestamp)
VALUES ('sess', ${i}, 'msg ${i}', '2026-01-01');
`);
await exec(db, `
INSERT INTO search_index (content, session_id, source_type, source_id)
VALUES ('content ${i}', 'sess', 'turn', 'sess:turn:${i}');
`);
}
await truncateSessionInDb(db, 'sess', 0);
const searchEntries = await all(db, 'SELECT * FROM search_index WHERE session_id = ?', ['sess']);
assert.strictEqual(searchEntries.length, 1);
assert.strictEqual(searchEntries[0].source_id, 'sess:turn:0');
} finally {
await close(db);
}
});
});
});

View File

@@ -346,6 +346,9 @@ export class RemoteAgentHostContribution extends Disposable implements IWorkbenc
canDelegate: true,
requiresCustomModels: true,
supportsDelegation: false,
capabilities: {
supportsCheckpoints: true,
},
}));
// Session handler (unified)

View File

@@ -12,7 +12,7 @@ import { localize, localize2 } from '../../../../../nls.js';
import { Action2, MenuId, registerAction2 } from '../../../../../platform/actions/common/actions.js';
import { ContextKeyExpr } from '../../../../../platform/contextkey/common/contextkey.js';
import { ServicesAccessor } from '../../../../../platform/instantiation/common/instantiation.js';
import { ChatContextKeys } from '../../common/actions/chatContextKeys.js';
import { ChatContextKeyExprs, ChatContextKeys } from '../../common/actions/chatContextKeys.js';
import { IChatService, ResponseModelState } from '../../common/chatService/chatService.js';
import type { ISerializableChatData } from '../../common/model/chatModel.js';
import { isChatTreeItem, isRequestVM, isResponseVM } from '../../common/model/chatViewModel.js';
@@ -40,7 +40,7 @@ export function registerChatForkActions() {
ChatContextKeys.isRequest,
ChatContextKeys.isFirstRequest.negate(),
ContextKeyExpr.or(
ChatContextKeys.lockedToCodingAgent.negate(),
ContextKeyExpr.or(ChatContextKeys.lockedToCodingAgent.negate(), ChatContextKeyExprs.isAgentHostSession),
ChatContextKeys.chatSessionSupportsFork
)
)

View File

@@ -14,7 +14,7 @@ import { ContextKeyExpr } from '../../../../../platform/contextkey/common/contex
import { IDialogService } from '../../../../../platform/dialogs/common/dialogs.js';
import { KeybindingWeight } from '../../../../../platform/keybinding/common/keybindingsRegistry.js';
import { IViewsService } from '../../../../services/views/common/viewsService.js';
import { ChatContextKeys } from '../../common/actions/chatContextKeys.js';
import { ChatContextKeyExprs, ChatContextKeys } from '../../common/actions/chatContextKeys.js';
import { IChatEditingSession } from '../../common/editing/chatEditingService.js';
import { IChatService } from '../../common/chatService/chatService.js';
import { ChatAgentLocation, ChatConfiguration, ChatModeKind } from '../../common/constants.js';
@@ -278,7 +278,7 @@ export function registerNewChatActions() {
f1: true,
menu: [{
id: MenuId.ChatMessageRestoreCheckpoint,
when: ChatContextKeys.lockedToCodingAgent.negate(),
when: ContextKeyExpr.or(ChatContextKeys.lockedToCodingAgent.negate(), ChatContextKeyExprs.isAgentHostSession),
group: 'navigation',
order: -1
}]

View File

@@ -147,6 +147,9 @@ export class AgentHostContribution extends Disposable implements IWorkbenchContr
description: agent.description,
canDelegate: true,
requiresCustomModels: true,
capabilities: {
supportsCheckpoints: true,
},
}));
// Session list controller

View File

@@ -46,7 +46,8 @@ interface IAgentHostFileEdit {
interface IAgentHostCheckpoint {
readonly requestId: string;
readonly undoStopId: string;
/** Tool-call ID, or `undefined` for the sentinel checkpoint at request start. */
readonly undoStopId: string | undefined;
readonly edits: IAgentHostFileEdit[];
}
@@ -120,7 +121,26 @@ export class AgentHostEditingSession extends Disposable implements IChatEditingS
private readonly _entriesObs = observableValue<readonly AgentHostModifiedFileEntry[]>(this, []);
readonly entries: IObservable<readonly IModifiedFileEntry[]> = this._entriesObs;
readonly requestDisablement: IObservable<IChatRequestDisablement[]> = constObservable([]);
readonly requestDisablement: IObservable<IChatRequestDisablement[]> = derivedOpts(
{ equalsFn: (a, b) => a.length === b.length && a.every((v, i) => v.requestId === b[i].requestId && v.afterUndoStop === b[i].afterUndoStop) },
reader => {
const currentIdx = this._currentCheckpointIndex.read(reader);
if (currentIdx >= this._checkpoints.length - 1) {
return [];
}
// Collect unique request IDs from checkpoints after the current
// index. Keep the first entry per request — if that's the sentinel
// (undoStopId === undefined) the entire request is disabled.
const disabled = new Map<string, string | undefined>();
for (let i = currentIdx + 1; i < this._checkpoints.length; i++) {
const cp = this._checkpoints[i];
if (!disabled.has(cp.requestId)) {
disabled.set(cp.requestId, cp.undoStopId);
}
}
return [...disabled].map(([requestId, afterUndoStop]): IChatRequestDisablement => ({ requestId, afterUndoStop }));
},
);
private readonly _onDidDispose = this._register(new Emitter<void>());
readonly onDidDispose: Event<void> = this._onDidDispose.event;
@@ -154,6 +174,27 @@ export class AgentHostEditingSession extends Disposable implements IChatEditingS
// ---- Hydration from protocol state --------------------------------------
/**
* Ensures a sentinel checkpoint exists for the given request. Called at the
* start of every turn so that `requestDisablement` and `restoreSnapshot`
* can reference requests that may not produce any file edits.
*
* Also splices away stale checkpoints after the current index (undo branch
* semantics) when a new request arrives after a checkpoint restore.
*/
ensureRequestCheckpoint(requestId: string): void {
// Splice stale checkpoints if the user restored a checkpoint
const currentIdx = this._currentCheckpointIndex.get();
if (currentIdx < this._checkpoints.length - 1) {
this._checkpoints.splice(currentIdx + 1);
}
// Insert sentinel for this request if it doesn't exist yet
if (!this._checkpoints.some(cp => cp.requestId === requestId)) {
this._checkpoints.push({ requestId, undoStopId: undefined, edits: [] });
}
}
addToolCallEdits(requestId: string, tc: IToolCallState): IChatProgress[] {
if (tc.status !== ToolCallStatus.Completed) {
return [];
@@ -164,6 +205,9 @@ export class AgentHostEditingSession extends Disposable implements IChatEditingS
return [];
}
// Ensure the sentinel and undo-branch splice are handled
this.ensureRequestCheckpoint(requestId);
const fileEdits = fileEditsToExternalEdits(tc);
if (fileEdits.length === 0) {
return [];
@@ -250,43 +294,54 @@ export class AgentHostEditingSession extends Disposable implements IChatEditingS
if (stopId !== undefined) {
return this._checkpoints.findIndex(cp => cp.requestId === requestId && cp.undoStopId === stopId);
}
// No specific stop: use the last checkpoint for this request
for (let i = this._checkpoints.length - 1; i >= 0; i--) {
if (this._checkpoints[i].requestId === requestId) {
return i;
}
}
return -1;
// No specific stop: find the sentinel checkpoint (undoStopId === undefined)
// for this request, which marks the request boundary.
return this._checkpoints.findIndex(cp => cp.requestId === requestId && cp.undoStopId === undefined);
}
private _findCheckpoint(requestId: string, stopId: string | undefined): IAgentHostCheckpoint | undefined {
const idx = this._findCheckpointIndex(requestId, stopId);
return idx >= 0 ? this._checkpoints[idx] : undefined;
if (stopId !== undefined) {
const idx = this._findCheckpointIndex(requestId, stopId);
return idx >= 0 ? this._checkpoints[idx] : undefined;
}
// No specific stop: find the last non-sentinel checkpoint for this
// request (the one with actual edits).
for (let i = this._checkpoints.length - 1; i >= 0; i--) {
const cp = this._checkpoints[i];
if (cp.requestId === requestId && cp.undoStopId !== undefined) {
return cp;
}
}
return undefined;
}
async restoreSnapshot(requestId: string, stopId: string | undefined): Promise<void> {
const idx = this._findCheckpointIndex(requestId, stopId);
if (idx < 0) {
const cpIdx = this._findCheckpointIndex(requestId, stopId);
if (cpIdx < 0) {
this._logService.warn(`[AgentHostEditingSession] No checkpoint found for requestId=${requestId}${stopId ? `, stopId=${stopId}` : ''}`);
return;
}
// When stopId is undefined we found the sentinel (request boundary).
// Navigate to one before it so the request's edits are fully undone.
const targetIdx = stopId === undefined ? cpIdx - 1 : cpIdx;
// Navigate to the target checkpoint
const currentIdx = this._currentCheckpointIndex.get();
if (idx < currentIdx) {
if (targetIdx < currentIdx) {
// Undo forward checkpoints
for (let i = currentIdx; i > idx; i--) {
for (let i = currentIdx; i > targetIdx; i--) {
await this._writeCheckpointContent(this._checkpoints[i], 'before');
}
} else if (idx > currentIdx) {
} else if (targetIdx > currentIdx) {
// Redo to reach the target
for (let i = currentIdx + 1; i <= idx; i++) {
for (let i = currentIdx + 1; i <= targetIdx; i++) {
await this._writeCheckpointContent(this._checkpoints[i], 'after');
}
}
transaction(tx => {
this._currentCheckpointIndex.set(idx, tx);
this._currentCheckpointIndex.set(targetIdx, tx);
});
this._rebuildEntries();
}
@@ -558,8 +613,14 @@ export class AgentHostEditingSession extends Disposable implements IChatEditingS
await this._writeCheckpointContent(this._checkpoints[idx], 'before');
// Skip past any sentinel checkpoints (they have no edits)
let newIdx = idx - 1;
while (newIdx >= 0 && this._checkpoints[newIdx].undoStopId === undefined) {
newIdx--;
}
transaction(tx => {
this._currentCheckpointIndex.set(idx - 1, tx);
this._currentCheckpointIndex.set(newIdx, tx);
});
this._rebuildEntries();
}
@@ -570,7 +631,15 @@ export class AgentHostEditingSession extends Disposable implements IChatEditingS
return;
}
const nextIdx = idx + 1;
// Skip past sentinel checkpoints to the next tool-call checkpoint
let nextIdx = idx + 1;
while (nextIdx < this._checkpoints.length && this._checkpoints[nextIdx].undoStopId === undefined) {
nextIdx++;
}
if (nextIdx >= this._checkpoints.length) {
return;
}
await this._writeCheckpointContent(this._checkpoints[nextIdx], 'after');
transaction(tx => {

View File

@@ -12,9 +12,9 @@ import { ResourceMap } from '../../../../../../base/common/map.js';
import { observableValue } from '../../../../../../base/common/observable.js';
import { isEqual } from '../../../../../../base/common/resources.js';
import { URI } from '../../../../../../base/common/uri.js';
import { generateUuid } from '../../../../../../base/common/uuid.js';
import { localize } from '../../../../../../nls.js';
import { AgentProvider, AgentSession, IAgentAttachment, type IAgentConnection } from '../../../../../../platform/agentHost/common/agentService.js';
import { ISessionTruncatedAction } from '../../../../../../platform/agentHost/common/state/protocol/actions.js';
import { ActionType, isSessionAction, type ISessionAction } from '../../../../../../platform/agentHost/common/state/sessionActions.js';
import { SessionClientState } from '../../../../../../platform/agentHost/common/state/sessionClientState.js';
import { AHP_AUTH_REQUIRED, ProtocolError } from '../../../../../../platform/agentHost/common/state/sessionProtocol.js';
@@ -26,7 +26,7 @@ import { ILogService } from '../../../../../../platform/log/common/log.js';
import { IProductService } from '../../../../../../platform/product/common/productService.js';
import { IWorkspaceContextService } from '../../../../../../platform/workspace/common/workspace.js';
import { ChatRequestQueueKind, IChatProgress, IChatService, IChatToolInvocation, ToolConfirmKind } from '../../../common/chatService/chatService.js';
import { IChatSession, IChatSessionContentProvider, IChatSessionHistoryItem } from '../../../common/chatSessionsService.js';
import { IChatSession, IChatSessionContentProvider, IChatSessionHistoryItem, IChatSessionItem, IChatSessionRequestHistoryItem } from '../../../common/chatSessionsService.js';
import { ChatAgentLocation, ChatModeKind } from '../../../common/constants.js';
import { IChatEditingService } from '../../../common/editing/chatEditingService.js';
import { ChatToolInvocation } from '../../../common/model/chatProgressTypes/chatToolInvocation.js';
@@ -59,11 +59,13 @@ class AgentHostChatSession extends Disposable implements IChatSession {
readonly requestHandler: IChatSession['requestHandler'];
interruptActiveResponseCallback: IChatSession['interruptActiveResponseCallback'];
readonly forkSession: IChatSession['forkSession'];
constructor(
readonly sessionResource: URI,
readonly history: readonly IChatSessionHistoryItem[],
private readonly _sendRequest: (request: IChatAgentRequest, progress: (parts: IChatProgress[]) => void, token: CancellationToken) => Promise<void>,
private readonly _forkSession: ((request: IChatSessionRequestHistoryItem | undefined, token: CancellationToken) => Promise<IChatSessionItem>) | undefined,
initialProgress: IChatProgress[] | undefined,
onDispose: () => void,
@ILogService private readonly _logService: ILogService,
@@ -91,6 +93,8 @@ class AgentHostChatSession extends Disposable implements IChatSession {
this.interruptActiveResponseCallback = (hasActiveTurn || history.length === 0) ? async () => {
return true;
} : undefined;
this.forkSession = this._forkSession;
}
/**
@@ -291,6 +295,9 @@ export class AgentHostSessionHandler extends Disposable implements IChatSessionC
this._ensurePendingMessageSubscription(sessionResource, backendSession);
return this._handleTurn(backendSession, request, progress, token);
},
resolvedSession
? (request, token) => this._forkSession(sessionResource, resolvedSession!, request, token)
: undefined,
initialProgress,
() => {
this._activeSessions.delete(sessionResource);
@@ -724,7 +731,7 @@ export class AgentHostSessionHandler extends Disposable implements IChatSessionC
return;
}
const turnId = generateUuid();
const turnId = request.requestId;
this._clientDispatchedTurnIds.add(turnId);
const cleanUpTurnId = () => this._clientDispatchedTurnIds.delete(turnId);
const attachments = this._convertVariablesToAttachments(request);
@@ -751,6 +758,36 @@ export class AgentHostSessionHandler extends Disposable implements IChatSessionC
}
}
// If the chat model has fewer previous requests than the protocol has
// turns, a checkpoint was restored or a message was edited. Dispatch
// session/truncated so the server drops the stale tail.
const chatModel = this._chatService.getSession(request.sessionResource);
const protocolState = this._clientState.getSessionState(session.toString());
if (chatModel && protocolState && protocolState.turns.length > 0) {
// -2 since -1 will already be the current request
const previousRequestIndex = chatModel.getRequests().findIndex(i => i.id === request.requestId) - 1;
const previousRequest = previousRequestIndex >= 0 ? chatModel.getRequests()[previousRequestIndex] : undefined;
if (!previousRequest && protocolState.turns.length > 0) {
const truncateAction: ISessionTruncatedAction = {
type: ActionType.SessionTruncated,
session: session.toString(),
};
const truncateSeq = this._clientState.applyOptimistic(truncateAction);
this._config.connection.dispatchAction(truncateAction, this._clientState.clientId, truncateSeq);
} else {
const seenAtIndex = protocolState.turns.findIndex(t => t.id === previousRequest!.id);
if (seenAtIndex !== -1 && seenAtIndex < protocolState.turns.length - 1) {
const truncateAction: ISessionTruncatedAction = {
type: ActionType.SessionTruncated,
session: session.toString(),
turnId: previousRequest!.id,
};
const truncateSeq = this._clientState.applyOptimistic(truncateAction);
this._config.connection.dispatchAction(truncateAction, this._clientState.clientId, truncateSeq);
}
}
}
// Dispatch session/turnStarted — the server will call sendMessage on
// the provider as a side effect.
const turnAction = {
@@ -765,6 +802,12 @@ export class AgentHostSessionHandler extends Disposable implements IChatSessionC
const clientSeq = this._clientState.applyOptimistic(turnAction);
this._config.connection.dispatchAction(turnAction, this._clientState.clientId, clientSeq);
// Ensure the editing session records a sentinel checkpoint for this
// request so it appears in requestDisablement even if the turn
// produces no file edits.
this._ensureEditingSession(request.sessionResource)
?.ensureRequestCheckpoint(request.requestId);
// Track live ChatToolInvocation objects for this turn
const activeToolInvocations = new Map<string, ChatToolInvocation>();
@@ -1211,14 +1254,68 @@ export class AgentHostSessionHandler extends Disposable implements IChatSessionC
return AgentSession.uri(this._config.provider, rawId);
}
/**
* Forks a session at the given request point by creating a new backend
* session with the `fork` parameter. Returns an {@link IChatSessionItem}
* pointing to the newly created session.
*/
private async _forkSession(
sessionResource: URI,
backendSession: URI,
request: IChatSessionRequestHistoryItem | undefined,
token: CancellationToken,
): Promise<IChatSessionItem> {
if (token.isCancellationRequested) {
throw new Error('Cancelled');
}
// Determine the turn index to fork at. If a specific request is
// provided, find its position in the protocol state's turn list.
// Otherwise fork the entire session.
const protocolState = this._clientState.getSessionState(backendSession.toString());
let turnIndex: number | undefined;
if (request) {
turnIndex = protocolState?.turns.findIndex(t => t.id === request.id);
if (turnIndex === undefined || turnIndex < 0) {
throw new Error(`Cannot fork: turn for request ${request.id} not found in protocol state`);
}
} else if (protocolState && protocolState.turns.length > 0) {
turnIndex = protocolState.turns.length - 1;
}
if (turnIndex === undefined) {
throw new Error('Cannot fork: no turns to fork from');
}
const chatModel = this._chatService.getSession(sessionResource);
const forkedSession = await this._createAndSubscribe(sessionResource, undefined, {
session: backendSession,
turnIndex,
});
const forkedRawId = AgentSession.id(forkedSession);
const forkedResource = URI.from({ scheme: this._config.sessionType, path: `/${forkedRawId}` });
const now = Date.now();
return {
resource: forkedResource,
label: chatModel?.title
? localize('chat.forked.title', "Forked: {0}", chatModel.title)
: `Forked session`,
iconPath: getAgentHostIcon(this._productService),
timing: { created: now, lastRequestStarted: now, lastRequestEnded: now },
};
}
/** Creates a new backend session and subscribes to its state. */
private async _createAndSubscribe(sessionResource: URI, modelId?: string): Promise<URI> {
private async _createAndSubscribe(sessionResource: URI, modelId?: string, fork?: { session: URI; turnIndex: number }): Promise<URI> {
const rawModelId = this._extractRawModelId(modelId);
const resourceKey = sessionResource.path.substring(1);
const workingDirectory = this._config.resolveWorkingDirectory?.(resourceKey)
?? this._workspaceContextService.getWorkspace().folders[0]?.uri;
this._logService.trace(`[AgentHost] Creating new session, model=${rawModelId ?? '(default)'}, provider=${this._config.provider}`);
this._logService.trace(`[AgentHost] Creating new session, model=${rawModelId ?? '(default)'}, provider=${this._config.provider}${fork ? `, fork from ${fork.session.toString()} at index ${fork.turnIndex}` : ''}`);
let session: URI;
try {
@@ -1226,6 +1323,7 @@ export class AgentHostSessionHandler extends Disposable implements IChatSessionC
model: rawModelId,
provider: this._config.provider,
workingDirectory,
fork,
});
} catch (err) {
// If authentication is required, try to resolve it and retry once
@@ -1237,6 +1335,7 @@ export class AgentHostSessionHandler extends Disposable implements IChatSessionC
model: rawModelId,
provider: this._config.provider,
workingDirectory,
fork,
});
} else {
throw new Error(localize('agentHost.authRequired', "Authentication is required to start a session. Please sign in and try again."));

View File

@@ -30,7 +30,7 @@ import { IEditorService } from '../../../../services/editor/common/editorService
import { IAgentSessionsService } from '../agentSessions/agentSessionsService.js';
import { IChatRequestVariableEntry, isImplicitVariableEntry, isPromptFileVariableEntry, isPromptTextVariableEntry, isStringVariableEntry, isWorkspaceVariableEntry } from '../../common/attachments/chatVariableEntries.js';
import { isChatViewTitleActionContext } from '../../common/actions/chatActions.js';
import { ChatContextKeys } from '../../common/actions/chatContextKeys.js';
import { ChatContextKeyExprs, ChatContextKeys } from '../../common/actions/chatContextKeys.js';
import { applyingChatEditsFailedContextKey, CHAT_EDITING_MULTI_DIFF_SOURCE_RESOLVER_SCHEME, chatEditingResourceContextKey, chatEditingWidgetFileStateContextKey, decidedChatEditingResourceContextKey, hasAppliedChatEditsContextKey, hasUndecidedChatEditingResourceContextKey, IChatEditingService, IChatEditingSession, ModifiedFileEntryState } from '../../common/editing/chatEditingService.js';
import { IChatService } from '../../common/chatService/chatService.js';
import { isChatTreeItem, isRequestVM, isResponseVM } from '../../common/model/chatViewModel.js';
@@ -533,7 +533,7 @@ registerAction2(class RemoveAction extends Action2 {
id: MenuId.ChatMessageTitle,
group: 'navigation',
order: 2,
when: ContextKeyExpr.and(ContextKeyExpr.equals(`config.${ChatConfiguration.EditRequests}`, 'input').negate(), ContextKeyExpr.equals(`config.${ChatConfiguration.CheckpointsEnabled}`, false), ChatContextKeys.lockedToCodingAgent.negate()),
when: ContextKeyExpr.and(ContextKeyExpr.equals(`config.${ChatConfiguration.EditRequests}`, 'input').negate(), ContextKeyExpr.equals(`config.${ChatConfiguration.CheckpointsEnabled}`, false), ContextKeyExpr.or(ChatContextKeys.lockedToCodingAgent.negate(), ChatContextKeyExprs.isAgentHostSession)),
}
]
});
@@ -586,7 +586,7 @@ registerAction2(class RestoreCheckpointAction extends Action2 {
id: MenuId.ChatMessageCheckpoint,
group: 'navigation',
order: 2,
when: ContextKeyExpr.and(ChatContextKeys.isRequest, ChatContextKeys.lockedToCodingAgent.negate(), ChatContextKeys.isFirstRequest.negate())
when: ContextKeyExpr.and(ChatContextKeys.isRequest, ContextKeyExpr.or(ChatContextKeys.lockedToCodingAgent.negate(), ChatContextKeyExprs.isAgentHostSession), ChatContextKeys.isFirstRequest.negate())
}
]
});
@@ -633,7 +633,7 @@ registerAction2(class StartOverAction extends Action2 {
id: MenuId.ChatMessageCheckpoint,
group: 'navigation',
order: 2,
when: ContextKeyExpr.and(ChatContextKeys.isRequest, ChatContextKeys.lockedToCodingAgent.negate(), ChatContextKeys.isFirstRequest)
when: ContextKeyExpr.and(ChatContextKeys.isRequest, ContextKeyExpr.or(ChatContextKeys.lockedToCodingAgent.negate(), ChatContextKeyExprs.isAgentHostSession), ChatContextKeys.isFirstRequest)
}
]
});
@@ -667,7 +667,7 @@ registerAction2(class RestoreLastCheckpoint extends Action2 {
precondition: ContextKeyExpr.and(
ChatContextKeys.inChatSession,
ContextKeyExpr.equals(`config.${ChatConfiguration.CheckpointsEnabled}`, true),
ChatContextKeys.lockedToCodingAgent.negate()
ContextKeyExpr.or(ChatContextKeys.lockedToCodingAgent.negate(), ChatContextKeyExprs.isAgentHostSession)
)
});
}

View File

@@ -199,6 +199,7 @@ const supportsAllAttachments: Required<IChatAgentAttachmentCapabilities> = {
supportsTerminalAttachments: true,
supportsPromptAttachments: true,
supportsHandOffs: true,
supportsCheckpoints: true,
};
const DISCLAIMER = localize('chatDisclaimer', "AI responses may be inaccurate");
@@ -2159,7 +2160,8 @@ export class ChatWidget extends Disposable implements IChatWidget {
// Update capabilities for the locked agent
const agent = this.chatAgentService.getAgent(agentId);
this._updateAgentCapabilitiesContextKeys(agent);
this.listWidget?.updateRendererOptions({ restorable: false, editable: false, noFooter: true, progressMessageAtBottomOfResponse: true });
const supportsCheckpoints = this._attachmentCapabilities.supportsCheckpoints ?? false;
this.listWidget?.updateRendererOptions({ restorable: supportsCheckpoints, editable: supportsCheckpoints, noFooter: true, progressMessageAtBottomOfResponse: true });
if (this.visible) {
this.listWidget?.rerender();
}
@@ -2337,8 +2339,17 @@ export class ChatWidget extends Disposable implements IChatWidget {
options.queue = undefined;
}
// For agents that support checkpoints, preserve the checkpoint
// through finishedEditing so blocked requests are removed below
// and the agent host can dispatch a protocol truncation action.
const preserveCheckpoint = this._lockedAgent && !!this._attachmentCapabilities.supportsCheckpoints;
if (preserveCheckpoint) {
this.recentlyRestoredCheckpoint = true;
}
this.finishedEditing(true);
this.viewModel.model?.setCheckpoint(undefined);
if (!preserveCheckpoint) {
this.viewModel.model?.setCheckpoint(undefined);
}
}
const model = this.viewModel.model;
@@ -2408,6 +2419,7 @@ export class ChatWidget extends Disposable implements IChatWidget {
this.chatService.removeRequest(this.viewModel.sessionResource, request.id);
}
}
this.viewModel.model.setCheckpoint(undefined);
}
// Expand directory attachments: extract images as binary entries
const resolvedImageVariables = await this._resolveDirectoryImageAttachments(requestInputs.attachedContext.asArray());

View File

@@ -161,6 +161,15 @@ export namespace ChatContextKeyExprs {
ChatContextKeys.chatModeKind.isEqualTo(ChatModeKind.Agent),
);
/**
* True when the locked coding agent is an agent host session (agent-host-* or remote-*).
* These sessions use AgentHostEditingSession which supports checkpoint-based undo/redo.
*/
export const isAgentHostSession = ContextKeyExpr.or(
ContextKeyExpr.regex(ChatContextKeys.lockedCodingAgentId.key, /^agent-host-/),
ContextKeyExpr.regex(ChatContextKeys.lockedCodingAgentId.key, /^remote-/),
);
/**
* Context expression that indicates when the welcome/setup view should be shown
*/

View File

@@ -50,6 +50,7 @@ export interface IChatAgentAttachmentCapabilities {
supportsTerminalAttachments?: boolean;
supportsPromptAttachments?: boolean;
supportsHandOffs?: boolean;
supportsCheckpoints?: boolean;
}
export interface IChatAgentData {

View File

@@ -675,4 +675,326 @@ suite('AgentHostEditingSession', () => {
assert.strictEqual(diff!.isFinal, true);
});
});
suite('requestDisablement', () => {
test('returns empty when at the latest checkpoint', () => {
const session = createSession(store, new Map());
session.addToolCallEdits('req-1', makeToolCall({
toolCallId: 'tc-1',
filePath: '/workspace/file.ts',
beforeURI: 'b',
afterURI: 'a',
}));
assert.deepStrictEqual(session.requestDisablement.get(), []);
});
test('disables requests after undo', async () => {
const contentMap = new Map<string, string>();
contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before');
const session = createSession(store, contentMap);
session.addToolCallEdits('req-1', makeToolCall({
toolCallId: 'tc-1',
filePath: '/workspace/file.ts',
beforeURI: URI.file('/workspace/file.ts').toString(),
afterURI: 'content://after-1',
}));
session.addToolCallEdits('req-2', makeToolCall({
toolCallId: 'tc-2',
filePath: '/workspace/file.ts',
beforeURI: 'content://after-1',
afterURI: 'content://after-2',
}));
// Undo the last tool call — req-2's sentinel + tool call are now after the cursor
await session.undoInteraction();
const disabled = session.requestDisablement.get();
assert.strictEqual(disabled.length, 1);
assert.strictEqual(disabled[0].requestId, 'req-2');
// The first checkpoint after current is the sentinel (undoStopId === undefined)
assert.strictEqual(disabled[0].afterUndoStop, undefined);
});
test('clears disablement after redo', async () => {
const contentMap = new Map<string, string>();
contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before');
contentMap.set(toAgentHostUri(URI.parse('content://after-2'), 'local').toString(), 'after-2');
const session = createSession(store, contentMap);
session.addToolCallEdits('req-1', makeToolCall({
toolCallId: 'tc-1',
filePath: '/workspace/file.ts',
beforeURI: URI.file('/workspace/file.ts').toString(),
afterURI: 'content://after-1',
}));
session.addToolCallEdits('req-2', makeToolCall({
toolCallId: 'tc-2',
filePath: '/workspace/file.ts',
beforeURI: 'content://after-1',
afterURI: 'content://after-2',
}));
await session.undoInteraction();
assert.strictEqual(session.requestDisablement.get().length, 1);
await session.redoInteraction();
assert.deepStrictEqual(session.requestDisablement.get(), []);
});
});
suite('restoreSnapshot', () => {
test('restoreSnapshot with undefined stopId navigates before the request', async () => {
const contentMap = new Map<string, string>();
contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before');
contentMap.set(toAgentHostUri(URI.parse('content://after-1'), 'local').toString(), 'after-1');
const session = createSession(store, contentMap);
session.addToolCallEdits('req-1', makeToolCall({
toolCallId: 'tc-1',
filePath: '/workspace/file.ts',
beforeURI: URI.file('/workspace/file.ts').toString(),
afterURI: 'content://after-1',
}));
session.addToolCallEdits('req-2', makeToolCall({
toolCallId: 'tc-2',
filePath: '/workspace/file.ts',
beforeURI: 'content://after-1',
afterURI: 'content://after-2',
}));
// Restore to before req-2 — should undo req-2's edits
const writes: IWriteFileParams[] = [];
store.add(session.onDidRequestFileWrite(p => writes.push(p)));
await session.restoreSnapshot('req-2', undefined);
// req-2 has a tool checkpoint whose before-content should be written
assert.ok(writes.length > 0);
// Entries should only show req-1's edits
assert.strictEqual(session.entries.get().length, 1);
assert.strictEqual(session.entries.get()[0].lastModifyingRequestId, 'req-1');
// req-2 should be disabled
assert.strictEqual(session.requestDisablement.get().length, 1);
assert.strictEqual(session.requestDisablement.get()[0].requestId, 'req-2');
});
test('restoreSnapshot with stopId navigates to that checkpoint', async () => {
const contentMap = new Map<string, string>();
contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before');
contentMap.set(toAgentHostUri(URI.parse('content://after-1'), 'local').toString(), 'after-1');
const session = createSession(store, contentMap);
session.addToolCallEdits('req-1', makeToolCall({
toolCallId: 'tc-1',
filePath: '/workspace/file.ts',
beforeURI: URI.file('/workspace/file.ts').toString(),
afterURI: 'content://after-1',
}));
session.addToolCallEdits('req-1', makeToolCall({
toolCallId: 'tc-2',
filePath: '/workspace/file.ts',
beforeURI: 'content://after-1',
afterURI: 'content://after-2',
}));
// Restore to specific tool call tc-1 within req-1
await session.restoreSnapshot('req-1', 'tc-1');
// Should keep tc-1 but not tc-2
assert.strictEqual(session.canUndo.get(), true);
assert.strictEqual(session.canRedo.get(), true);
});
test('restoreSnapshot for first request navigates to before all edits', async () => {
const contentMap = new Map<string, string>();
contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before');
const session = createSession(store, contentMap);
session.addToolCallEdits('req-1', makeToolCall({
toolCallId: 'tc-1',
filePath: '/workspace/file.ts',
beforeURI: URI.file('/workspace/file.ts').toString(),
afterURI: 'content://after-1',
}));
await session.restoreSnapshot('req-1', undefined);
// No entries visible, all disabled
assert.deepStrictEqual(session.entries.get(), []);
assert.strictEqual(session.canUndo.get(), false);
assert.strictEqual(session.requestDisablement.get().length, 1);
assert.strictEqual(session.requestDisablement.get()[0].requestId, 'req-1');
});
});
suite('undo branch (splice stale checkpoints)', () => {
test('new edits after undo remove stale checkpoints', async () => {
const contentMap = new Map<string, string>();
contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before');
const session = createSession(store, contentMap);
session.addToolCallEdits('req-1', makeToolCall({
toolCallId: 'tc-1',
filePath: '/workspace/file.ts',
beforeURI: URI.file('/workspace/file.ts').toString(),
afterURI: 'content://after-1',
}));
session.addToolCallEdits('req-2', makeToolCall({
toolCallId: 'tc-2',
filePath: '/workspace/file.ts',
beforeURI: 'content://after-1',
afterURI: 'content://after-2',
}));
// Undo last checkpoint
await session.undoInteraction();
// Now add a new edit — should splice away req-2's sentinel + checkpoint
session.addToolCallEdits('req-3', makeToolCall({
toolCallId: 'tc-3',
filePath: '/workspace/file.ts',
beforeURI: 'content://after-1',
afterURI: 'content://after-3',
}));
// req-2 should be gone, req-3 should be present
assert.strictEqual(session.canRedo.get(), false);
assert.strictEqual(session.requestDisablement.get().length, 0);
assert.strictEqual(session.hasEditsInRequest('req-2'), false);
assert.strictEqual(session.hasEditsInRequest('req-3'), true);
});
test('sentinel checkpoint is preserved after splice for new request', async () => {
const contentMap = new Map<string, string>();
contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before');
const session = createSession(store, contentMap);
session.addToolCallEdits('req-1', makeToolCall({
toolCallId: 'tc-1',
filePath: '/workspace/file.ts',
beforeURI: URI.file('/workspace/file.ts').toString(),
afterURI: 'content://after-1',
}));
// Undo
await session.undoInteraction();
// Add new request — sentinel should survive the splice
session.addToolCallEdits('req-2', makeToolCall({
toolCallId: 'tc-2',
filePath: '/workspace/file.ts',
beforeURI: URI.file('/workspace/file.ts').toString(),
afterURI: 'content://after-2',
}));
// Undo once (tc-2), then check that req-2 is disabled via the sentinel
await session.undoInteraction();
const disabled = session.requestDisablement.get();
assert.strictEqual(disabled.length, 1);
assert.strictEqual(disabled[0].requestId, 'req-2');
assert.strictEqual(disabled[0].afterUndoStop, undefined);
});
});
suite('ensureRequestCheckpoint', () => {
test('creates sentinel for request without tool calls', () => {
const session = createSession(store, new Map());
session.ensureRequestCheckpoint('req-1');
// No entries visible (sentinel has no edits)
assert.deepStrictEqual(session.entries.get(), []);
// hasEditsInRequest returns true because the sentinel exists
assert.strictEqual(session.hasEditsInRequest('req-1'), true);
});
test('request without edits appears in requestDisablement after undo', async () => {
const contentMap = new Map<string, string>();
contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before');
const session = createSession(store, contentMap);
// req-1 has file edits
session.addToolCallEdits('req-1', makeToolCall({
toolCallId: 'tc-1',
filePath: '/workspace/file.ts',
beforeURI: URI.file('/workspace/file.ts').toString(),
afterURI: 'content://after-1',
}));
// req-2 has no file edits, only a sentinel
session.ensureRequestCheckpoint('req-2');
// Undo tc-1 — both req-1 (partially) and req-2 should be disabled
await session.undoInteraction();
const disabled = session.requestDisablement.get();
const disabledIds = disabled.map(d => d.requestId);
assert.ok(disabledIds.includes('req-2'), 'req-2 should be disabled');
});
test('restoreSnapshot works for request with only sentinel', async () => {
const contentMap = new Map<string, string>();
contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before');
const session = createSession(store, contentMap);
session.addToolCallEdits('req-1', makeToolCall({
toolCallId: 'tc-1',
filePath: '/workspace/file.ts',
beforeURI: URI.file('/workspace/file.ts').toString(),
afterURI: 'content://after-1',
}));
session.ensureRequestCheckpoint('req-2');
// Restore to before req-2 — should keep req-1's edits but disable req-2
await session.restoreSnapshot('req-2', undefined);
assert.strictEqual(session.entries.get().length, 1);
assert.strictEqual(session.entries.get()[0].lastModifyingRequestId, 'req-1');
assert.strictEqual(session.requestDisablement.get().length, 1);
assert.strictEqual(session.requestDisablement.get()[0].requestId, 'req-2');
});
test('idempotent — calling twice does not duplicate', () => {
const session = createSession(store, new Map());
session.ensureRequestCheckpoint('req-1');
session.ensureRequestCheckpoint('req-1');
// Should still work — only one sentinel
assert.strictEqual(session.hasEditsInRequest('req-1'), true);
});
test('splices stale checkpoints when called after restore', async () => {
const contentMap = new Map<string, string>();
contentMap.set(toAgentHostUri(URI.file('/workspace/file.ts'), 'local').toString(), 'before');
const session = createSession(store, contentMap);
session.addToolCallEdits('req-1', makeToolCall({
toolCallId: 'tc-1',
filePath: '/workspace/file.ts',
beforeURI: URI.file('/workspace/file.ts').toString(),
afterURI: 'content://after-1',
}));
session.ensureRequestCheckpoint('req-2');
// Undo to before req-2's sentinel
await session.undoInteraction();
// New request should splice away req-2
session.ensureRequestCheckpoint('req-3');
assert.strictEqual(session.hasEditsInRequest('req-2'), false);
assert.strictEqual(session.hasEditsInRequest('req-3'), true);
});
});
});