Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 27 additions & 19 deletions tests/unit/vertexai/genai/test_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2297,10 +2297,10 @@ def test_convert_tool_call_parts(self):
)


class TestAgentMetadata:
"""Unit tests for the AgentMetadata class."""
class TestAgentInfo:
"""Unit tests for the AgentInfo class."""

def test_agent_metadata_creation(self):
def test_agent_info_creation(self):
tool = genai_types.Tool(
function_declarations=[
genai_types.FunctionDeclaration(
Expand All @@ -2313,18 +2313,16 @@ def test_agent_metadata_creation(self):
)
]
)
agent_metadata = vertexai_genai_types.AgentMetadata(
agent_info = vertexai_genai_types.AgentInfo(
name="agent1",
instruction="instruction1",
description="description1",
tool_declarations=[tool],
sub_agent_names=["sub_agent1"],
)
assert agent_metadata.name == "agent1"
assert agent_metadata.instruction == "instruction1"
assert agent_metadata.description == "description1"
assert agent_metadata.tool_declarations == [tool]
assert agent_metadata.sub_agent_names == ["sub_agent1"]
assert agent_info.name == "agent1"
assert agent_info.instruction == "instruction1"
assert agent_info.description == "description1"
assert agent_info.tool_declarations == [tool]


class TestEvent:
Expand Down Expand Up @@ -2359,13 +2357,11 @@ def test_eval_case_with_agent_eval_fields(self):
)
]
)
agent_metadata = {
"agent1": vertexai_genai_types.AgentMetadata(
name="agent1",
instruction="instruction1",
tool_declarations=[tool],
)
}
agent_info = vertexai_genai_types.AgentInfo(
name="agent1",
instruction="instruction1",
tool_declarations=[tool],
)
intermediate_events = [
vertexai_genai_types.Event(
event_id="event1",
Expand All @@ -2381,14 +2377,26 @@ def test_eval_case_with_agent_eval_fields(self):
response=genai_types.Content(parts=[genai_types.Part(text="Hi")])
)
],
agent_metadata=agent_metadata,
agent_info=agent_info,
intermediate_events=intermediate_events,
)

assert eval_case.agent_metadata == agent_metadata
assert eval_case.agent_info == agent_info
assert eval_case.intermediate_events == intermediate_events


class TestSessionInput:
    """Unit tests for the SessionInput class."""

    def test_session_input_creation(self):
        """SessionInput exposes the user id and state it was built with."""
        session_input = vertexai_genai_types.SessionInput(
            user_id="user1",
            state={"key": "value"},
        )
        assert session_input.user_id == "user1"
        assert session_input.state == {"key": "value"}


class TestMetric:
"""Unit tests for the Metric class."""

Expand Down
56 changes: 39 additions & 17 deletions vertexai/_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10398,8 +10398,8 @@ class EvalRunInferenceConfigDict(TypedDict, total=False):
EvalRunInferenceConfigOrDict = Union[EvalRunInferenceConfig, EvalRunInferenceConfigDict]


class AgentMetadata(_common.BaseModel):
"""AgentMetadata for agent eval."""
class AgentInfo(_common.BaseModel):
"""The agent info of an agent, used for agent eval."""

name: Optional[str] = Field(
default=None, description="""Agent name, used as an identifier."""
Expand All @@ -10413,13 +10413,10 @@ class AgentMetadata(_common.BaseModel):
tool_declarations: Optional[genai_types.ToolListUnion] = Field(
default=None, description="""List of tools used by the Agent."""
)
sub_agent_names: Optional[list[str]] = Field(
default=None, description="""List of sub-agent names."""
)


class AgentMetadataDict(TypedDict, total=False):
"""AgentMetadata for agent eval."""
class AgentInfoDict(TypedDict, total=False):
"""The agent info of an agent, used for agent eval."""

name: Optional[str]
"""Agent name, used as an identifier."""
Expand All @@ -10433,11 +10430,8 @@ class AgentMetadataDict(TypedDict, total=False):
tool_declarations: Optional[genai_types.ToolListUnionDict]
"""List of tools used by the Agent."""

sub_agent_names: Optional[list[str]]
"""List of sub-agent names."""


AgentMetadataOrDict = Union[AgentMetadata, AgentMetadataDict]
AgentInfoOrDict = Union[AgentInfo, AgentInfoDict]


class ContentMapContents(_common.BaseModel):
Expand Down Expand Up @@ -10669,11 +10663,11 @@ class EvalCase(_common.BaseModel):
)
intermediate_events: Optional[list[Event]] = Field(
default=None,
description="""Intermediate events of a single turn in agent eval or intermediate events of the last turn for multi-turn agent eval.""",
description="""This field is experimental and may change in future versions. Intermediate events of a single turn in an agent run or intermediate events of the last turn for a multi-turn agent run.""",
)
agent_metadata: Optional[dict[str, AgentMetadata]] = Field(
agent_info: Optional[AgentInfo] = Field(
default=None,
description="""Agent metadata for agent eval, keyed by agent name. This can be extended for multi-agent evaluation.""",
description="""This field is experimental and may change in future versions. The agent info of the agent under evaluation. This can be extended for multi-agent evaluation.""",
)
# Allow extra fields to support custom metric prompts and stay backward compatible.
model_config = ConfigDict(frozen=True, extra="allow")
Expand Down Expand Up @@ -10704,10 +10698,10 @@ class EvalCaseDict(TypedDict, total=False):
"""Unique identifier for the evaluation case."""

intermediate_events: Optional[list[EventDict]]
"""Intermediate events of a single turn in agent eval or intermediate events of the last turn for multi-turn agent eval."""
"""This field is experimental and may change in future versions. Intermediate events of a single turn in an agent run or intermediate events of the last turn for a multi-turn agent run."""

agent_metadata: Optional[dict[str, AgentMetadataDict]]
"""Agent metadata for agent eval, keyed by agent name. This can be extended for multi-agent evaluation."""
agent_info: Optional[AgentInfoDict]
"""This field is experimental and may change in future versions. The agent info of the agent under evaluation. This can be extended for multi-agent evaluation."""


EvalCaseOrDict = Union[EvalCase, EvalCaseDict]
Expand Down Expand Up @@ -11076,6 +11070,34 @@ class EvaluationResultDict(TypedDict, total=False):
EvaluationResultOrDict = Union[EvaluationResult, EvaluationResultDict]


class SessionInput(_common.BaseModel):
    """Input to initialize a session and run an agent, used for agent evaluation.

    This class is experimental and may change in future versions.
    """

    user_id: Optional[str] = Field(
        default=None,
        description="""The user id.""",
    )
    state: Optional[dict[str, str]] = Field(
        default=None,
        description="""The state of the session.""",
    )


class SessionInputDict(TypedDict, total=False):
    """Input to initialize a session and run an agent, used for agent evaluation.

    This type is experimental and may change in future versions. It is
    declared with ``total=False``, so every key may be omitted.
    """

    user_id: Optional[str]
    """The user id."""

    state: Optional[dict[str, str]]
    """The state of the session."""


# Alias for annotations that accept either the SessionInput model or its
# SessionInputDict form.
SessionInputOrDict = Union[SessionInput, SessionInputDict]


class WinRateStats(_common.BaseModel):
"""Statistics for win rates for a single metric."""

Expand Down
Loading