Commit 0901f11

community: add truncation params when an openai assistant's run is created (#28158)
**Description:** When an OpenAI assistant is invoked, it creates a run that exposes only a few request fields to the user. The truncation strategy defaults to `auto`, which packs previous thread messages in with the current question until the context length is reached. Token usage therefore grows with every turn: consumed_tokens = previous_consumed_tokens + current_consumed_tokens. This PR adds support for user-defined truncation strategies (`truncation_strategy` and `max_prompt_tokens`), giving better control over token consumption.

**Issue:** High token consumption.
1 parent c09000f commit 0901f11
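
For illustration only, a minimal usage sketch (not part of this commit) of how the new keys could be passed through the runnable's input dict once this change is in place; the assistant id, thread id, and parameter values below are placeholders:

```python
from langchain_community.agents.openai_assistant import OpenAIAssistantV2Runnable

# Placeholder assistant id; as_agent=False returns the raw OpenAI run output.
assistant = OpenAIAssistantV2Runnable(assistant_id="asst_abc123", as_agent=False)

output = assistant.invoke(
    {
        "content": "Summarize our discussion so far.",
        "thread_id": "thread_abc123",  # placeholder id of an existing thread
        # New keys forwarded to client.beta.threads.runs.create by this change:
        "truncation_strategy": {"type": "last_messages", "last_messages": 10},
        "max_prompt_tokens": 2000,
    }
)
```

With a `last_messages` truncation strategy, only the most recent messages in the thread count toward the prompt, which caps the per-run token growth described above.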

File tree

2 files changed: +49 −5 lines

libs/community/langchain_community/agents/openai_assistant/base.py

+10-5
@@ -543,11 +543,16 @@ def _create_run(self, input: dict) -> Any:
         Returns:
             Any: The created run object.
         """
-        params = {
-            k: v
-            for k, v in input.items()
-            if k in ("instructions", "model", "tools", "tool_resources", "run_metadata")
-        }
+        allowed_assistant_params = (
+            "instructions",
+            "model",
+            "tools",
+            "tool_resources",
+            "run_metadata",
+            "truncation_strategy",
+            "max_prompt_tokens",
+        )
+        params = {k: v for k, v in input.items() if k in allowed_assistant_params}
         return self.client.beta.threads.runs.create(
             input["thread_id"],
             assistant_id=self.assistant_id,

New test file: +39 −0

@@ -0,0 +1,39 @@
+from typing import Any
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from langchain_community.agents.openai_assistant import OpenAIAssistantV2Runnable
+
+
+def _create_mock_client(*args: Any, use_async: bool = False, **kwargs: Any) -> Any:
+    client = AsyncMock() if use_async else MagicMock()
+    client.beta.threads.runs.create = MagicMock(return_value=None)  # type: ignore
+    return client
+
+
+@pytest.mark.requires("openai")
+def test_set_run_truncation_params() -> None:
+    client = _create_mock_client()
+
+    assistant = OpenAIAssistantV2Runnable(assistant_id="assistant_xyz", client=client)
+    input = {
+        "content": "AI question",
+        "thread_id": "thread_xyz",
+        "instructions": "You're a helpful assistant; answer questions as best you can.",
+        "model": "gpt-4o",
+        "max_prompt_tokens": 2000,
+        "truncation_strategy": {"type": "last_messages", "last_messages": 10},
+    }
+    expected_response = {
+        "assistant_id": "assistant_xyz",
+        "instructions": "You're a helpful assistant; answer questions as best you can.",
+        "model": "gpt-4o",
+        "max_prompt_tokens": 2000,
+        "truncation_strategy": {"type": "last_messages", "last_messages": 10},
+    }
+
+    assistant._create_run(input=input)
+    _, kwargs = client.beta.threads.runs.create.call_args
+
+    assert kwargs == expected_response
