Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions lib/commands/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ function analyzeTokens(state: SessionState, messages: WithParts[]): TokenBreakdo
allMessageIds.add(msg.info.id)
const parts = Array.isArray(msg.parts) ? msg.parts : []
const isCompacted = isMessageCompacted(state, msg)
const isMessagePruned = state.prune.messages.has(msg.info.id)
const pruneEntry = state.prune.messages.byMessageId.get(msg.info.id)
const isMessagePruned = !!pruneEntry && pruneEntry.activeBlockIds.length > 0
const isIgnoredUser = msg.info.role === "user" && isIgnoredUserMessage(msg)

for (const part of parts) {
Expand Down Expand Up @@ -190,8 +191,8 @@ function analyzeTokens(state: SessionState, messages: WithParts[]): TokenBreakdo
const toolsInContextCount = [...activeToolIds].filter((id) => !prunedByToolIds.has(id)).length

let prunedMessageCount = 0
for (const id of state.prune.messages.keys()) {
if (allMessageIds.has(id)) {
for (const [id, entry] of state.prune.messages.byMessageId) {
if (allMessageIds.has(id) && entry.activeBlockIds.length > 0) {
prunedMessageCount++
}
}
Expand Down
19 changes: 17 additions & 2 deletions lib/commands/stats.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,23 @@ export async function handleStatsCommand(ctx: StatsCommandContext): Promise<void

// Session stats from in-memory state
const sessionTokens = state.stats.totalPruneTokens
const sessionTools = state.prune.tools.size
const sessionMessages = state.prune.messages.size

const prunedToolIds = new Set<string>(state.prune.tools.keys())
for (const block of state.prune.messages.blocksById.values()) {
if (block.active) {
for (const toolId of block.effectiveToolIds) {
prunedToolIds.add(toolId)
}
}
}
const sessionTools = prunedToolIds.size

let sessionMessages = 0
for (const entry of state.prune.messages.byMessageId.values()) {
if (entry.activeBlockIds.length > 0) {
sessionMessages++
}
}

// All-time stats from storage files
const allTime = await loadAllSessionStats(logger)
Expand Down
3 changes: 2 additions & 1 deletion lib/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import type { Logger } from "./logger"
import type { PluginConfig } from "./config"
import { assignMessageRefs } from "./message-ids"
import { syncToolCache } from "./state/tool-cache"
import { prune, insertCompressNudges, insertMessageIds } from "./messages"
import { prune, syncCompressionBlocks, insertCompressNudges, insertMessageIds } from "./messages"
import { buildToolIdList, isIgnoredUserMessage, stripHallucinations } from "./messages/utils"
import { checkSession } from "./state"
import { renderSystemPrompt } from "./prompts"
Expand Down Expand Up @@ -111,6 +111,7 @@ export function createChatMessageTransformHandler(
stripHallucinations(output.messages)
cacheSystemPromptTokens(state, output.messages)
assignMessageRefs(state, output.messages)
syncCompressionBlocks(state, logger, output.messages)
syncToolCache(state, config, logger, output.messages)
buildToolIdList(state, output.messages)
prune(state, logger, config, output.messages)
Expand Down
1 change: 1 addition & 0 deletions lib/messages/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export { prune } from "./prune"
export { syncCompressionBlocks } from "./sync"
export { insertCompressNudges } from "./inject/inject"
export { insertMessageIds } from "./inject/inject"
18 changes: 14 additions & 4 deletions lib/messages/prune.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,10 @@ const filterCompressedRanges = (
logger: Logger,
messages: WithParts[],
): void => {
if (!state.prune.messages?.size) {
if (
state.prune.messages.byMessageId.size === 0 &&
state.prune.messages.activeByAnchorMessageId.size === 0
) {
return
}

Expand All @@ -178,10 +181,16 @@ const filterCompressedRanges = (
const msgId = msg.info.id

// Check if there's a summary to inject at this anchor point
const summary = state.compressSummaries?.find((s) => s?.anchorMessageId === msgId)
const blockId = state.prune.messages.activeByAnchorMessageId.get(msgId)
const summary =
blockId !== undefined ? state.prune.messages.blocksById.get(blockId) : undefined
if (summary) {
const rawSummaryContent = (summary as { summary?: unknown }).summary
if (typeof rawSummaryContent !== "string" || rawSummaryContent.length === 0) {
if (
summary.active !== true ||
typeof rawSummaryContent !== "string" ||
rawSummaryContent.length === 0
) {
logger.warn("Skipping malformed compress summary", {
anchorMessageId: msgId,
blockId: (summary as { blockId?: unknown }).blockId,
Expand Down Expand Up @@ -217,7 +226,8 @@ const filterCompressedRanges = (
}

// Skip messages that are in the prune list
if (state.prune.messages.has(msgId)) {
const pruneEntry = state.prune.messages.byMessageId.get(msgId)
if (pruneEntry && pruneEntry.activeBlockIds.length > 0) {
continue
}

Expand Down
115 changes: 115 additions & 0 deletions lib/messages/sync.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import type { SessionState, WithParts } from "../state"
import type { Logger } from "../logger"

function sortBlocksByCreation(
a: { createdAt: number; blockId: number },
b: { createdAt: number; blockId: number },
): number {
const createdAtDiff = a.createdAt - b.createdAt
if (createdAtDiff !== 0) {
return createdAtDiff
}
return a.blockId - b.blockId
}

export const syncCompressionBlocks = (
state: SessionState,
logger: Logger,
messages: WithParts[],
): void => {
const messagesState = state.prune.messages
if (!messagesState?.blocksById?.size) {
return
}

const messageIds = new Set(messages.map((msg) => msg.info.id))
const previousActiveBlockIds = new Set<number>(
Array.from(messagesState.blocksById.values())
.filter((block) => block.active)
.map((block) => block.blockId),
)

messagesState.activeBlockIds.clear()
messagesState.activeByAnchorMessageId.clear()

const now = Date.now()
const missingOriginBlockIds: number[] = []
const orderedBlocks = Array.from(messagesState.blocksById.values()).sort(sortBlocksByCreation)

for (const block of orderedBlocks) {
const hasOriginMessage =
typeof block.compressMessageId === "string" &&
block.compressMessageId.length > 0 &&
messageIds.has(block.compressMessageId)

if (!hasOriginMessage) {
block.active = false
block.deactivatedAt = now
block.deactivatedByBlockId = undefined
missingOriginBlockIds.push(block.blockId)
continue
}

for (const consumedBlockId of block.consumedBlockIds) {
if (!messagesState.activeBlockIds.has(consumedBlockId)) {
continue
}

const consumedBlock = messagesState.blocksById.get(consumedBlockId)
if (consumedBlock) {
consumedBlock.active = false
consumedBlock.deactivatedAt = now
consumedBlock.deactivatedByBlockId = block.blockId

const mappedBlockId = messagesState.activeByAnchorMessageId.get(
consumedBlock.anchorMessageId,
)
if (mappedBlockId === consumedBlock.blockId) {
messagesState.activeByAnchorMessageId.delete(consumedBlock.anchorMessageId)
}
}

messagesState.activeBlockIds.delete(consumedBlockId)
}

block.active = true
block.deactivatedAt = undefined
block.deactivatedByBlockId = undefined
messagesState.activeBlockIds.add(block.blockId)
if (messageIds.has(block.anchorMessageId)) {
messagesState.activeByAnchorMessageId.set(block.anchorMessageId, block.blockId)
}
}

for (const entry of messagesState.byMessageId.values()) {
const allBlockIds = Array.isArray(entry.allBlockIds)
? [...new Set(entry.allBlockIds.filter((id) => Number.isInteger(id) && id > 0))]
: []

entry.allBlockIds = allBlockIds
entry.activeBlockIds = allBlockIds.filter((id) => messagesState.activeBlockIds.has(id))
}

const nextActiveBlockIds = messagesState.activeBlockIds
let deactivatedCount = 0
let reactivatedCount = 0

for (const blockId of previousActiveBlockIds) {
if (!nextActiveBlockIds.has(blockId)) {
deactivatedCount++
}
}
for (const blockId of nextActiveBlockIds) {
if (!previousActiveBlockIds.has(blockId)) {
reactivatedCount++
}
}

if (missingOriginBlockIds.length > 0 || deactivatedCount > 0 || reactivatedCount > 0) {
logger.info("Synced compress block state", {
missingOriginCount: missingOriginBlockIds.length,
deactivatedCount,
reactivatedCount,
})
}
}
3 changes: 2 additions & 1 deletion lib/shared-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ export const isMessageCompacted = (state: SessionState, msg: WithParts): boolean
if (msg.info.time.created < state.lastCompaction) {
return true
}
if (state.prune.messages.has(msg.info.id)) {
const pruneEntry = state.prune.messages.byMessageId.get(msg.info.id)
if (pruneEntry && pruneEntry.activeBlockIds.length > 0) {
return true
}
return false
Expand Down
68 changes: 31 additions & 37 deletions lib/state/persistence.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,21 @@ import * as fs from "fs/promises"
import { existsSync } from "fs"
import { homedir } from "os"
import { join } from "path"
import type { SessionState, SessionStats, CompressSummary } from "./types"
import type { CompressionBlock, PrunedMessageEntry, SessionState, SessionStats } from "./types"
import type { Logger } from "../logger"

/** Prune state as stored on disk */
export interface PersistedPruneMessagesState {
byMessageId: Record<string, PrunedMessageEntry>
blocksById: Record<string, CompressionBlock>
activeBlockIds: number[]
activeByAnchorMessageId: Record<string, number>
nextBlockId: number
}

export interface PersistedPrune {
// New format: tool/message IDs with token counts
tools?: Record<string, number>
messages?: Record<string, number>
// Legacy format: plain ID arrays (backward compatibility)
toolIds?: string[]
messageIds?: string[]
messages?: PersistedPruneMessagesState
}

export interface PersistedNudges {
Expand All @@ -30,7 +34,6 @@ export interface PersistedNudges {
export interface PersistedSessionState {
sessionName?: string
prune: PersistedPrune
compressSummaries: CompressSummary[]
nudges: PersistedNudges
stats: SessionStats
lastUpdated: string
Expand Down Expand Up @@ -70,9 +73,20 @@ export async function saveSessionState(
sessionName: sessionName,
prune: {
tools: Object.fromEntries(sessionState.prune.tools),
messages: Object.fromEntries(sessionState.prune.messages),
messages: {
byMessageId: Object.fromEntries(sessionState.prune.messages.byMessageId),
blocksById: Object.fromEntries(
Array.from(sessionState.prune.messages.blocksById.entries()).map(
([blockId, block]) => [String(blockId), block],
),
),
activeBlockIds: Array.from(sessionState.prune.messages.activeBlockIds),
activeByAnchorMessageId: Object.fromEntries(
sessionState.prune.messages.activeByAnchorMessageId,
),
nextBlockId: sessionState.prune.messages.nextBlockId,
},
},
compressSummaries: sessionState.compressSummaries,
nudges: {
contextLimitAnchors: Array.from(sessionState.nudges.contextLimitAnchors),
turnNudgeAnchors: Array.from(sessionState.nudges.turnNudgeAnchors),
Expand Down Expand Up @@ -112,13 +126,14 @@ export async function loadSessionState(
const content = await fs.readFile(filePath, "utf-8")
const state = JSON.parse(content) as PersistedSessionState

const hasNewFormat = state?.prune?.tools && typeof state.prune.tools === "object"
const hasLegacyFormat = Array.isArray(state?.prune?.toolIds)
const hasPruneTools = state?.prune?.tools && typeof state.prune.tools === "object"
const hasPruneMessages = state?.prune?.messages && typeof state.prune.messages === "object"
const hasNudgeFormat = state?.nudges && typeof state.nudges === "object"
if (
!state ||
!state.prune ||
(!hasNewFormat && !hasLegacyFormat) ||
!hasPruneTools ||
!hasPruneMessages ||
!state.stats ||
!hasNudgeFormat
) {
Expand All @@ -128,27 +143,6 @@ export async function loadSessionState(
return null
}

if (Array.isArray(state.compressSummaries)) {
const validSummaries = state.compressSummaries.filter(
(s): s is CompressSummary =>
s !== null &&
typeof s === "object" &&
typeof s.blockId === "number" &&
typeof s.anchorMessageId === "string" &&
typeof s.summary === "string",
)
if (validSummaries.length !== state.compressSummaries.length) {
logger.warn("Filtered out malformed compressSummaries entries", {
sessionId: sessionId,
original: state.compressSummaries.length,
valid: validSummaries.length,
})
}
state.compressSummaries = validSummaries
} else {
state.compressSummaries = []
}

const rawContextLimitAnchors = Array.isArray(state.nudges.contextLimitAnchors)
? state.nudges.contextLimitAnchors
: []
Expand Down Expand Up @@ -244,10 +238,10 @@ export async function loadAllSessionStats(logger: Logger): Promise<AggregatedSta
result.totalTokens += state.stats.totalPruneTokens
result.totalTools += state.prune.tools
? Object.keys(state.prune.tools).length
: (state.prune.toolIds?.length ?? 0)
result.totalMessages += state.prune.messages
? Object.keys(state.prune.messages).length
: (state.prune.messageIds?.length ?? 0)
: 0
result.totalMessages += state.prune.messages?.byMessageId
? Object.keys(state.prune.messages.byMessageId).length
: 0
result.sessionCount++
}
} catch {
Expand Down
13 changes: 6 additions & 7 deletions lib/state/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import {
findLastCompactionTimestamp,
countTurns,
resetOnCompaction,
createPruneMessagesState,
loadPruneMessagesState,
loadPruneMap,
collectTurnNudgeAnchors,
} from "./utils"
Expand Down Expand Up @@ -67,9 +69,8 @@ export function createSessionState(): SessionState {
pendingManualTrigger: null,
prune: {
tools: new Map<string, number>(),
messages: new Map<string, number>(),
messages: createPruneMessagesState(),
},
compressSummaries: [],
nudges: {
contextLimitAnchors: new Set<string>(),
turnNudgeAnchors: new Set<string>(),
Expand Down Expand Up @@ -101,9 +102,8 @@ export function resetSessionState(state: SessionState): void {
state.pendingManualTrigger = null
state.prune = {
tools: new Map<string, number>(),
messages: new Map<string, number>(),
messages: createPruneMessagesState(),
}
state.compressSummaries = []
state.nudges = {
contextLimitAnchors: new Set<string>(),
turnNudgeAnchors: new Set<string>(),
Expand Down Expand Up @@ -159,9 +159,8 @@ export async function ensureSessionInitialized(
return
}

state.prune.tools = loadPruneMap(persisted.prune.tools, persisted.prune.toolIds)
state.prune.messages = loadPruneMap(persisted.prune.messages, persisted.prune.messageIds)
state.compressSummaries = persisted.compressSummaries || []
state.prune.tools = loadPruneMap(persisted.prune.tools)
state.prune.messages = loadPruneMessagesState(persisted.prune.messages)
state.nudges.contextLimitAnchors = new Set<string>(persisted.nudges.contextLimitAnchors || [])
state.nudges.turnNudgeAnchors = new Set<string>([
...state.nudges.turnNudgeAnchors,
Expand Down
Loading