Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2075,10 +2075,10 @@ impl Session {
}

pub(crate) async fn recompute_token_usage(&self, turn_context: &TurnContext) {
let Some(estimated_total_tokens) = self
.clone_history()
.await
.estimate_token_count(turn_context)
let history = self.clone_history().await;
let base_instructions = self.get_base_instructions().await;
let Some(estimated_total_tokens) =
history.estimate_token_count_with_base_instructions(&base_instructions)
else {
return;
};
Expand Down Expand Up @@ -4734,6 +4734,7 @@ mod tests {
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_app_server_protocol::AppInfo;
use codex_app_server_protocol::AuthMode;
use codex_protocol::models::BaseInstructions;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use std::path::Path;
Expand Down Expand Up @@ -5013,6 +5014,46 @@ mod tests {
assert_eq!(actual, Some(info2));
}

/// `recompute_token_usage` must estimate against the session's own base
/// instructions (which can be overridden at runtime), not the model-default
/// instructions derived from the turn context.
#[tokio::test]
async fn recompute_token_usage_uses_session_base_instructions() {
    let (session, turn_context) = make_session_and_context().await;

    // Install an oversized override so its estimate clearly diverges from the
    // model-default instructions.
    let custom_text = "SESSION_OVERRIDE_INSTRUCTIONS_ONLY".repeat(120);
    {
        // Scope the guard so the state lock is released before further awaits.
        let mut guard = session.state.lock().await;
        guard.session_configuration.base_instructions = custom_text.clone();
    }

    let message = user_message("hello");
    session
        .record_into_history(std::slice::from_ref(&message), &turn_context)
        .await;

    let history = session.clone_history().await;
    let override_base = BaseInstructions { text: custom_text };
    let with_override = history
        .estimate_token_count_with_base_instructions(&override_base)
        .expect("estimate with session base instructions");
    let with_model_defaults = history
        .estimate_token_count(&turn_context)
        .expect("estimate with model instructions");
    // Sanity check: the two estimates must actually differ, otherwise the
    // final assertion would not distinguish the code paths.
    assert_ne!(with_override, with_model_defaults);

    session.recompute_token_usage(&turn_context).await;

    let recorded_tokens = session
        .state
        .lock()
        .await
        .token_info()
        .expect("token info")
        .last_token_usage
        .total_tokens;
    assert_eq!(recorded_tokens, with_override.max(0));
}

#[tokio::test]
async fn record_initial_history_reconstructs_forked_transcript() {
let (session, turn_context) = make_session_and_context().await;
Expand Down
14 changes: 10 additions & 4 deletions codex-rs/core/src/compact_remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::protocol::RolloutItem;
use crate::protocol::TurnStartedEvent;
use codex_protocol::items::ContextCompactionItem;
use codex_protocol::items::TurnItem;
use codex_protocol::models::BaseInstructions;
use codex_protocol::models::ResponseItem;
use tracing::info;

Expand Down Expand Up @@ -49,8 +50,12 @@ async fn run_remote_compact_task_inner_impl(
sess.emit_turn_item_started(turn_context, &compaction_item)
.await;
let mut history = sess.clone_history().await;
let deleted_items =
trim_function_call_history_to_fit_context_window(&mut history, turn_context.as_ref());
let base_instructions = sess.get_base_instructions().await;
let deleted_items = trim_function_call_history_to_fit_context_window(
&mut history,
turn_context.as_ref(),
&base_instructions,
);
if deleted_items > 0 {
info!(
turn_id = %turn_context.sub_id,
Expand All @@ -71,7 +76,7 @@ async fn run_remote_compact_task_inner_impl(
input: history.for_prompt(),
tools: vec![],
parallel_tool_calls: false,
base_instructions: sess.get_base_instructions().await,
base_instructions,
personality: turn_context.personality,
output_schema: None,
};
Expand Down Expand Up @@ -102,14 +107,15 @@ async fn run_remote_compact_task_inner_impl(
fn trim_function_call_history_to_fit_context_window(
history: &mut ContextManager,
turn_context: &TurnContext,
base_instructions: &BaseInstructions,
) -> usize {
let mut deleted_items = 0usize;
let Some(context_window) = turn_context.model_context_window() else {
return deleted_items;
};

while history
.estimate_token_count(turn_context)
.estimate_token_count_with_base_instructions(base_instructions)
.is_some_and(|estimated_tokens| estimated_tokens > context_window)
{
let Some(last_item) = history.raw_items().last() else {
Expand Down
15 changes: 13 additions & 2 deletions codex-rs/core/src/context_manager/history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::truncate::approx_tokens_from_byte_count;
use crate::truncate::truncate_function_output_items_with_policy;
use crate::truncate::truncate_text;
use crate::user_shell_command::is_user_shell_command_text;
use codex_protocol::models::BaseInstructions;
use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputContentItem;
Expand Down Expand Up @@ -88,8 +89,18 @@ impl ContextManager {
/// Estimate the total token footprint of this history using the model's
/// default instructions (honoring a turn-level personality override when one
/// is present).
pub(crate) fn estimate_token_count(&self, turn_context: &TurnContext) -> Option<i64> {
    // A personality set on the turn takes precedence over the config-level one.
    let personality = turn_context.personality.or(turn_context.config.personality);
    let base_instructions = BaseInstructions {
        text: turn_context.model_info.get_model_instructions(personality),
    };
    self.estimate_token_count_with_base_instructions(&base_instructions)
}

pub(crate) fn estimate_token_count_with_base_instructions(
&self,
base_instructions: &BaseInstructions,
) -> Option<i64> {
let base_tokens =
i64::try_from(approx_token_count(&base_instructions.text)).unwrap_or(i64::MAX);

let items_tokens = self.items.iter().fold(0i64, |acc, item| {
acc.saturating_add(estimate_item_token_count(item))
Expand Down
27 changes: 27 additions & 0 deletions codex-rs/core/src/context_manager/history_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::*;
use crate::truncate;
use crate::truncate::TruncationPolicy;
use codex_git::GhostCommit;
use codex_protocol::models::BaseInstructions;
use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputContentItem;
Expand Down Expand Up @@ -103,6 +104,10 @@ fn truncate_exec_output(content: &str) -> String {
truncate::truncate_text(content, TruncationPolicy::Tokens(EXEC_FORMAT_MAX_TOKENS))
}

/// Test-side mirror of the production token heuristic: roughly one token per
/// four bytes, rounded up, saturating rather than overflowing on conversion.
fn approx_token_count_for_text(text: &str) -> i64 {
    let approx_tokens = text.len().saturating_add(3) / 4;
    i64::try_from(approx_tokens).unwrap_or(i64::MAX)
}

#[test]
fn filters_non_api_messages() {
let mut h = ContextManager::default();
Expand Down Expand Up @@ -250,6 +255,28 @@ fn get_history_for_prompt_drops_ghost_commits() {
assert_eq!(filtered, vec![]);
}

/// The estimator must derive its base-instruction cost from the text it is
/// handed: swapping in longer instructions shifts the estimate by exactly the
/// approximate token count of the extra text.
#[test]
fn estimate_token_count_with_base_instructions_uses_provided_text() {
    let history = create_history_with_items(vec![assistant_msg("hello from history")]);
    let short_base = BaseInstructions {
        text: "short".to_string(),
    };
    let long_base = BaseInstructions {
        text: "x".repeat(1_000),
    };

    // Both estimates share the same history, so only the base-instruction
    // term can differ between them.
    let estimate_for = |base: &BaseInstructions| {
        history
            .estimate_token_count_with_base_instructions(base)
            .expect("token estimate")
    };
    let short_estimate = estimate_for(&short_base);
    let long_estimate = estimate_for(&long_base);

    let expected_delta = approx_token_count_for_text(&long_base.text)
        - approx_token_count_for_text(&short_base.text);
    assert_eq!(long_estimate - short_estimate, expected_delta);
}

#[test]
fn remove_first_item_removes_matching_output_for_function_call() {
let items = vec![
Expand Down
Loading
Loading