Skip to content

Commit 1c0dcbd

Browse files
feat: support extra headers for model (#103)
1 parent 4f4c3c0 commit 1c0dcbd

File tree

4 files changed

+32
-6
lines changed

4 files changed

+32
-6
lines changed

tests/test_agent.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from google.adk.tools import load_memory
1616

1717
from veadk import Agent
18+
from veadk.consts import DEFAULT_MODEL_EXTRA_HEADERS
1819
from veadk.knowledgebase import KnowledgeBase
1920
from veadk.memory.long_term_memory import LongTermMemory
2021
from veadk.tools import load_knowledgebase_tool
@@ -26,11 +27,14 @@ def test_agent():
2627
long_term_memory = LongTermMemory(backend="local")
2728
tracer = OpentelemetryTracer()
2829

30+
model_extra_headers = {"test-header": "test-value"}
31+
2932
agent = Agent(
3033
model_name="test_model_name",
3134
model_provider="test_model_provider",
3235
model_api_key="test_model_api_key",
3336
model_api_base="test_model_api_base",
37+
model_extra_headers=model_extra_headers,
3438
tools=[],
3539
sub_agents=[],
3640
knowledgebase=knowledgebase,
@@ -39,7 +43,10 @@ def test_agent():
3943
serve_url="",
4044
)
4145

46+
model_extra_headers |= DEFAULT_MODEL_EXTRA_HEADERS
47+
4248
assert agent.model.model == f"{agent.model_provider}/{agent.model_name}"
49+
assert agent.model_extra_headers == model_extra_headers
4350

4451
assert agent.knowledgebase == knowledgebase
4552
assert agent.knowledgebase.backend == "local"

veadk/agent.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
DEFALUT_MODEL_AGENT_PROVIDER,
3232
DEFAULT_MODEL_AGENT_API_BASE,
3333
DEFAULT_MODEL_AGENT_NAME,
34+
DEFAULT_MODEL_EXTRA_HEADERS,
3435
)
3536
from veadk.evaluation import EvalSetRecorder
3637
from veadk.knowledgebase import KnowledgeBase
@@ -73,6 +74,9 @@ class Agent(LlmAgent):
7374
model_api_key: str = Field(default_factory=lambda: getenv("MODEL_AGENT_API_KEY"))
7475
"""The api key of the model for agent running."""
7576

77+
model_extra_headers: dict = Field(default_factory=dict)
78+
"""The extra headers to include in the model requests."""
79+
7680
tools: list[ToolUnion] = []
7781
"""The tools provided to agent."""
7882

@@ -96,11 +100,23 @@ class Agent(LlmAgent):
96100

97101
def model_post_init(self, __context: Any) -> None:
98102
super().model_post_init(None) # for sub_agents init
99-
self.model = LiteLlm(
100-
model=f"{self.model_provider}/{self.model_name}",
101-
api_key=self.model_api_key,
102-
api_base=self.model_api_base,
103-
)
103+
104+
self.model_extra_headers |= DEFAULT_MODEL_EXTRA_HEADERS
105+
106+
if not self.model:
107+
self.model = LiteLlm(
108+
model=f"{self.model_provider}/{self.model_name}",
109+
api_key=self.model_api_key,
110+
api_base=self.model_api_base,
111+
extra_headers=self.model_extra_headers,
112+
)
113+
logger.debug(
114+
f"LiteLLM client created with extra headers: {self.model_extra_headers}"
115+
)
116+
else:
117+
logger.warning(
118+
"You are trying to use your own LiteLLM client, some default request headers may be missing."
119+
)
104120

105121
if self.knowledgebase:
106122
from veadk.tools import load_knowledgebase_tool

veadk/consts.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from veadk.version import VERSION
16+
1517
DEFAULT_MODEL_AGENT_NAME = "doubao-seed-1-6-250615"
1618
DEFALUT_MODEL_AGENT_PROVIDER = "openai"
1719
DEFAULT_MODEL_AGENT_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/"
20+
DEFAULT_MODEL_EXTRA_HEADERS = {"veadk-source": "veadk", "veadk-version": VERSION}

veadk/tracing/telemetry/telemetry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def set_common_attributes(
6464
)
6565
return
6666

67-
if isinstance(invocation_context.agent, Agent):
67+
if isinstance(invocation_context.agent, Agent) and invocation_context.agent.tracers:
6868
try:
6969
from veadk.tracing.telemetry.opentelemetry_tracer import OpentelemetryTracer
7070

0 commit comments

Comments
 (0)