
Commit 9d8342b

process cache_control in kwargs
1 parent: 5187817

2 files changed: +67 −0


libs/partners/anthropic/langchain_anthropic/chat_models.py

Lines changed: 20 additions & 0 deletions
@@ -1291,6 +1291,26 @@ def _get_request_payload(
     ) -> dict:
         messages = self._convert_input(input_).to_messages()
         system, formatted_messages = _format_messages(messages)
+
+        # If cache_control is provided in kwargs, add it to last message
+        # and content block.
+        if "cache_control" in kwargs and formatted_messages:
+            if isinstance(formatted_messages[-1]["content"], list):
+                formatted_messages[-1]["content"][-1]["cache_control"] = kwargs.pop(
+                    "cache_control"
+                )
+            elif isinstance(formatted_messages[-1]["content"], str):
+                formatted_messages[-1]["content"] = [
+                    {
+                        "type": "text",
+                        "text": formatted_messages[-1]["content"],
+                        "cache_control": kwargs.pop("cache_control"),
+                    }
+                ]
+            else:
+                pass
+        _ = kwargs.pop("cache_control", None)
+
         payload = {
             "model": self.model,
             "max_tokens": self.max_tokens,

libs/partners/anthropic/tests/unit_tests/test_chat_models.py

Lines changed: 47 additions & 0 deletions
@@ -1056,3 +1056,50 @@ def mock_create(*args: Any, **kwargs: Any) -> Message:
     # Test headers are correctly propagated to request
     payload = llm._get_request_payload([input_message])
     assert payload["mcp_servers"][0]["authorization_token"] == "PLACEHOLDER"
+
+
+def test_cache_control_kwarg() -> None:
+    llm = ChatAnthropic(model="claude-3-5-haiku-latest")
+
+    messages = [HumanMessage("foo"), AIMessage("bar"), HumanMessage("baz")]
+    payload = llm._get_request_payload(messages)
+    assert payload["messages"] == [
+        {"role": "user", "content": "foo"},
+        {"role": "assistant", "content": "bar"},
+        {"role": "user", "content": "baz"},
+    ]
+
+    payload = llm._get_request_payload(messages, cache_control={"type": "ephemeral"})
+    assert payload["messages"] == [
+        {"role": "user", "content": "foo"},
+        {"role": "assistant", "content": "bar"},
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "baz", "cache_control": {"type": "ephemeral"}}
+            ],
+        },
+    ]
+
+    messages = [
+        HumanMessage("foo"),
+        AIMessage("bar"),
+        HumanMessage(
+            content=[
+                {"type": "text", "text": "baz"},
+                {"type": "text", "text": "qux"},
+            ]
+        ),
+    ]
+    payload = llm._get_request_payload(messages, cache_control={"type": "ephemeral"})
+    assert payload["messages"] == [
+        {"role": "user", "content": "foo"},
+        {"role": "assistant", "content": "bar"},
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "baz"},
+                {"type": "text", "text": "qux", "cache_control": {"type": "ephemeral"}},
+            ],
+        },
+    ]
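
Since test_cache_control_kwarg only calls _get_request_payload, it asserts on the constructed payload and sends nothing over the network. With the partner package's test environment in place, it can be selected via pytest's node-id syntax, e.g. pytest libs/partners/anthropic/tests/unit_tests/test_chat_models.py::test_cache_control_kwarg.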
