Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -369,13 +369,16 @@ def get_attributes_from_agent_collaborator_invocation_input(
# Create message
messages = [Message(content=content, role="assistant")]

# Create metadata
metadata = {
"invocation_type": "agent_collaborator_invocation",
"agent_collaborator_name": collaborator_input.get("agentCollaboratorName"),
"agent_collaborator_alias_arn": collaborator_input.get("agentCollaboratorAliasArn"),
"input_type": input_type,
}
# Create metadata - merge invocation metadata with agent-specific fields
metadata = cls.get_metadata_attributes(collaborator_input.get("metadata", {}))
metadata.update(
{
"invocation_type": "agent_collaborator_invocation",
"agent_collaborator_name": collaborator_input.get("agentCollaboratorName"),
"agent_collaborator_alias_arn": collaborator_input.get("agentCollaboratorAliasArn"),
"input_type": input_type,
}
)

return {
**get_span_kind_attributes(OpenInferenceSpanKindValues.AGENT),
Expand Down Expand Up @@ -444,12 +447,17 @@ def get_attributes_from_agent_collaborator_invocation_output(
# Create message
messages = [Message(role="assistant", content=output_value)]

# Create metadata
metadata = {
"agent_collaborator_name": collaborator_output.get("agentCollaboratorName"),
"agent_collaborator_alias_arn": collaborator_output.get("agentCollaboratorAliasArn"),
"output_type": output_type,
}
# Create metadata - merge observation metadata with agent-specific fields
metadata = cls.get_metadata_attributes(collaborator_output.get("metadata", {}))
metadata.update(
{
"agent_collaborator_name": collaborator_output.get("agentCollaboratorName"),
"agent_collaborator_alias_arn": collaborator_output.get(
"agentCollaboratorAliasArn"
),
"output_type": output_type,
}
)

return {
**get_output_attributes(output_value),
Expand Down Expand Up @@ -542,6 +550,7 @@ def get_event_type(cls, trace_data: dict[str, Any]) -> str:
"guardrailTrace",
"postProcessingTrace",
"failureTrace",
"routingClassifierTrace",
]
for trace_event in trace_events:
if trace_event in trace_data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,30 @@ def _set_parent_span_input_attributes(
if not isinstance(trace_span, TraceNode):
return

# First check the node's own chunks (for agent-collaborator nodes)
for trace_data in trace_span.chunks:
trace_event = AttributeExtractor.get_event_type(trace_data)
event_data = trace_data.get(trace_event, {})

# Extract from invocation input in node chunks
if "invocationInput" in event_data:
invocation_input = event_data.get("invocationInput", {})
# For agent-collaborator nodes, get full attributes including LLM messages
if trace_span.node_type == "agent-collaborator":
attrs = AttributeExtractor.get_attributes_from_invocation_input(
invocation_input
)
else:
attrs = AttributeExtractor.get_parent_input_attributes_from_invocation_input(
invocation_input
)
if attrs:
input_attributes.update(attrs)
input_attributes.update(attributes.request_attributes)
attributes.request_attributes = input_attributes
return

# Then check spans
for span in trace_span.spans:
for trace_data in span.chunks:
trace_event = AttributeExtractor.get_event_type(trace_data)
Expand All @@ -341,10 +365,11 @@ def _set_parent_span_input_attributes(
attrs = AttributeExtractor.get_parent_input_attributes_from_invocation_input(
invocation_input
)
input_attributes.update(attrs)
input_attributes.update(attributes.request_attributes)
attributes.request_attributes = input_attributes
return
if attrs:
input_attributes.update(attrs)
input_attributes.update(attributes.request_attributes)
attributes.request_attributes = input_attributes
return

# Recursively check nested nodes
if isinstance(span, TraceNode):
Expand All @@ -363,6 +388,26 @@ def _set_parent_span_output_attributes(
if not isinstance(trace_span, TraceNode):
return

# First check the node's own chunks (for agent-collaborator nodes)
for trace_data in trace_span.chunks[::-1]:
trace_event = AttributeExtractor.get_event_type(trace_data)
event_data = trace_data.get(trace_event, {})

# Extract from observation in node chunks
if "observation" in event_data:
observation = event_data.get("observation", {})
# For agent-collaborator nodes, get full output attributes including LLM messages
if trace_span.node_type == "agent-collaborator":
attrs = AttributeExtractor.get_attributes_from_observation(observation)
if attrs:
attributes.request_attributes.update(attrs)
return
elif final_response := observation.get("finalResponse"):
if text := final_response.get("text", ""):
attributes.request_attributes.update(get_output_attributes(text))
return

# Then check spans
for span in trace_span.spans[::-1]:
for trace_data in span.chunks[::-1]:
trace_event = AttributeExtractor.get_event_type(trace_data)
Expand All @@ -382,6 +427,30 @@ def _set_parent_span_output_attributes(
return attributes.request_attributes.update(
get_output_attributes(output_text)
)
# For Routing classifier events, the output is in rawResponse.content
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does input data/ metadata data show up correctly for routing classifier traces?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call on raw_response_content, I'll change the name to something a little more clear.

As for the routingClassifierTrace, it does have the input/output data from the LLM call, but the metadata isn't captured. This is due to how the TraceNodes are handled, the function loops through the chunks of the trace_span_data and uses the events in the chunks to populate the span attributes, but it also sets the name of the span. If we add the trace data as a chunk to the TraceNode then _prepare_span_attributes will see the modelInvocationInput or modelInvocationOutput attributes in the trace data which will cause it to reassign the name to LLM instead of routingClassifierTrace or orchestrationTrace. We could probably get around that issue though by adding a _process_trace_node method that only extracts the metadata from the chunks without reassigning the span name

if parsed_response == {}:
raw_response = model_invocation_output.get("rawResponse", {})
if raw_response_content := raw_response.get("content"):
try:
response_content_json = json.loads(raw_response_content)
if (
output_content := response_content_json.get("output", {})
.get("message", {})
.get("content")
):
return attributes.request_attributes.update(
get_output_attributes(output_content)
)
# Return full parsed json if output isn't found
return attributes.request_attributes.update(
get_output_attributes(response_content_json)
)
except Exception:
pass
# Fallback to raw response if content was not valid JSON
return attributes.request_attributes.update(
get_output_attributes(raw_response_content)
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Routing Classifier JSON Handling Issue

In the routing classifier logic, raw_response_content is reassigned from its original string to a parsed JSON object. If the expected nested output isn't found, the fallback uses this JSON object, which get_output_attributes() likely expects as a string.

Fix in Cursor Fix in Web

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@luke-moehlenbrock I think this is fine given what get_output_attributes does, but I think it's valid feedback that raw_response_content should ideally not be re-assigned, so that it's easier to follow and so the the data takes on something other than what it's name is describing. I'd rename 390 to something else and then use that in a return statement within the try block if the subsequent isn't true (i.e. no 'output" on the json)


# Extract from invocation input
if "observation" in event_data:
Expand Down Expand Up @@ -468,7 +537,13 @@ def _prepare_span_attributes(cls, trace_span_data: Union[TraceSpan, TraceNode])

# Set name from node type if it's a TraceNode
if isinstance(trace_span_data, TraceNode):
_attributes.name = trace_span_data.node_type
if trace_span_data.node_type == "guardrailTrace":
pre_or_post = trace_span_data.node_trace_id.split("-")[-1]
_attributes.name = pre_or_post + "GuardrailTrace"
else:
_attributes.name = trace_span_data.node_type
cls._process_trace_node(trace_span_data, _attributes)
return _attributes

# Process each chunk in the trace span
for trace_data in trace_span_data.chunks:
Expand Down Expand Up @@ -502,6 +577,57 @@ def _prepare_span_attributes(cls, trace_span_data: Union[TraceSpan, TraceNode])
cls._process_failure_trace(event_data, _attributes)
return _attributes

@classmethod
def _process_trace_node(cls, trace_node: TraceNode, attributes: _Attributes) -> None:
    """
    Extract metadata attributes from a trace node's chunks and merge them into
    the node span's attributes. Applies to routingClassifierTrace,
    orchestrationTrace, guardrailTrace, agent-collaborator nodes, etc.
    """
    is_collaborator = trace_node.node_type == "agent-collaborator"

    for chunk in trace_node.chunks:
        event_key = AttributeExtractor.get_event_type(chunk)
        event = chunk.get(event_key, {})

        # Agent-collaborator nodes are renamed after the collaborator they invoke.
        if is_collaborator and "invocationInput" in event:
            inv_input = event.get("invocationInput", {})
            if "agentCollaboratorInvocationInput" in inv_input:
                collaborator_name = inv_input.get(
                    "agentCollaboratorInvocationInput", {}
                ).get("agentCollaboratorName", "")
                inv_type = inv_input.get("invocationType", "")
                if collaborator_name:
                    attributes.name = f"{inv_type.lower()}[{collaborator_name}]"

        # Merge child-level metadata first; trace-level metadata below overrides it.
        if "modelInvocationOutput" in event:
            llm_output = event.get("modelInvocationOutput", {})
            attributes.metadata.update(
                AttributeExtractor.get_metadata_attributes(llm_output.get("metadata"))
            )

        observation = event.get("observation")
        if observation:
            # Agent-collaborator observations carry their own metadata payload.
            if is_collaborator and (obs_meta := observation.get("metadata")):
                attributes.metadata.update(
                    AttributeExtractor.get_observation_metadata_attributes(obs_meta)
                )
            # Other node types surface metadata on the finalResponse.
            final_resp = observation.get("finalResponse")
            if final_resp and (final_meta := final_resp.get("metadata")):
                attributes.metadata.update(
                    AttributeExtractor.get_metadata_attributes(final_meta)
                )

        # Trace-level metadata is applied last so it takes precedence
        # (for orchestrationTrace, guardrailTrace, etc.).
        if trace_meta := event.get("metadata"):
            attributes.metadata.update(
                AttributeExtractor.get_metadata_attributes(trace_meta)
            )

@classmethod
def _process_model_invocation_input(
cls, event_data: Dict[str, Any], attributes: _Attributes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ def generate_unique_trace_id(event_type: str, trace_id: str) -> str:
Returns:
A unique trace ID string
"""
if "guardrail" in trace_id:
# if guardrail, use the first 7 parts of the trace id in order to differentiate
# between pre and post guardrail; it will look something like this:
# 4ce64021-13b2-23c5-9d70-aaefe8881138-guardrail-pre-0
return f"{event_type}_{'-'.join(trace_id.split('-')[:7])}"
return f"{event_type}_{'-'.join(trace_id.split('-')[:5])}"


Expand Down Expand Up @@ -210,6 +215,8 @@ def _handle_chunk_for_current_node(
trace_data: The trace data to add to the node or span
"""
if chunk_type not in ["invocationInput", "modelInvocationInput"]:
# Add chunk to trace node as well, useful for propagating metadata to the parent node
node.add_chunk(trace_data)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Trace Data Duplication Causes Processing Errors

Potential duplicate chunk data: The code adds trace_data to node.chunks in _handle_chunk_for_current_node (line 219), but trace_data is also added to node.chunks when creating new trace nodes in _handle_new_trace_node (lines 257 and 263). This means the same trace_data could be added twice to the TraceNode's chunks list - once during node creation and once during chunk processing. This duplication could lead to incorrect metadata extraction or span attribute processing as the same data would be processed multiple times when iterating over node.chunks in _process_trace_node.

Fix in Cursor Fix in Web

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally, the first chunk will have invocationInput or modelInvocationInput, this might not always be the case but it should be the case for most TraceNodes. In the case where we add a duplicate chunk to a node, it shouldn't have any impact. The function that processes span attributes loops through the chunks in a node and uses each chunk to write attributes to the _attributes object. For duplicate chunks, the second chunk would just overwrite the attributes with the exact same values so there shouldn't be any unexpected behaviour here. We need to add the chunks here to ensure that routingClassifierTrace and orchestrationTrace properly capture the output from their respective events.

if node.current_span:
node.current_span.add_chunk(trace_data)
else:
Expand Down Expand Up @@ -247,11 +254,13 @@ def _handle_new_trace_node(
trace_node.chunks.append(trace_data)
elif event_type == "guardrailTrace":
trace_node = TraceNode(node_trace_id, event_type)
trace_node.add_chunk(trace_data)
trace_span = TraceSpan(chunk_type)
trace_span.add_chunk(trace_data)
trace_node.add_span(trace_span)
else:
trace_node = TraceNode(node_trace_id, event_type)
trace_node.add_chunk(trace_data)
trace_span = TraceSpan(chunk_type)
trace_span.add_chunk(trace_data)
trace_span.parent_node = parent_node # This is child for the Agent Span
Expand Down
Loading