feat(bedrock): Add Routing Classifier Trace event #2342
```diff
@@ -317,6 +317,30 @@ def _set_parent_span_input_attributes(
         if not isinstance(trace_span, TraceNode):
             return
 
+        # First check the node's own chunks (for agent-collaborator nodes)
+        for trace_data in trace_span.chunks:
+            trace_event = AttributeExtractor.get_event_type(trace_data)
+            event_data = trace_data.get(trace_event, {})
+
+            # Extract from invocation input in node chunks
+            if "invocationInput" in event_data:
+                invocation_input = event_data.get("invocationInput", {})
+                # For agent-collaborator nodes, get full attributes including LLM messages
+                if trace_span.node_type == "agent-collaborator":
+                    attrs = AttributeExtractor.get_attributes_from_invocation_input(
+                        invocation_input
+                    )
+                else:
+                    attrs = AttributeExtractor.get_parent_input_attributes_from_invocation_input(
+                        invocation_input
+                    )
+                if attrs:
+                    input_attributes.update(attrs)
+                    input_attributes.update(attributes.request_attributes)
+                    attributes.request_attributes = input_attributes
+                    return
+
+        # Then check spans
         for span in trace_span.spans:
             for trace_data in span.chunks:
                 trace_event = AttributeExtractor.get_event_type(trace_data)
```
```diff
@@ -341,10 +365,11 @@ def _set_parent_span_input_attributes(
                     attrs = AttributeExtractor.get_parent_input_attributes_from_invocation_input(
                         invocation_input
                     )
-                    input_attributes.update(attrs)
-                    input_attributes.update(attributes.request_attributes)
-                    attributes.request_attributes = input_attributes
-                    return
+                    if attrs:
+                        input_attributes.update(attrs)
+                        input_attributes.update(attributes.request_attributes)
+                        attributes.request_attributes = input_attributes
+                        return
 
             # Recursively check nested nodes
             if isinstance(span, TraceNode):
```
```diff
@@ -363,6 +388,26 @@ def _set_parent_span_output_attributes(
         if not isinstance(trace_span, TraceNode):
             return
 
+        # First check the node's own chunks (for agent-collaborator nodes)
+        for trace_data in trace_span.chunks[::-1]:
+            trace_event = AttributeExtractor.get_event_type(trace_data)
+            event_data = trace_data.get(trace_event, {})
+
+            # Extract from observation in node chunks
+            if "observation" in event_data:
+                observation = event_data.get("observation", {})
+                # For agent-collaborator nodes, get full output attributes including LLM messages
+                if trace_span.node_type == "agent-collaborator":
+                    attrs = AttributeExtractor.get_attributes_from_observation(observation)
+                    if attrs:
+                        attributes.request_attributes.update(attrs)
+                        return
+                elif final_response := observation.get("finalResponse"):
+                    if text := final_response.get("text", ""):
+                        attributes.request_attributes.update(get_output_attributes(text))
+                        return
+
+        # Then check spans
         for span in trace_span.spans[::-1]:
             for trace_data in span.chunks[::-1]:
                 trace_event = AttributeExtractor.get_event_type(trace_data)
```
```diff
@@ -382,6 +427,30 @@ def _set_parent_span_output_attributes(
                         return attributes.request_attributes.update(
                             get_output_attributes(output_text)
                         )
+                    # For Routing classifier events, the output is in rawResponse.content
+                    if parsed_response == {}:
+                        raw_response = model_invocation_output.get("rawResponse", {})
+                        if raw_response_content := raw_response.get("content"):
+                            try:
+                                response_content_json = json.loads(raw_response_content)
+                                if (
+                                    output_content := response_content_json.get("output", {})
+                                    .get("message", {})
+                                    .get("content")
+                                ):
+                                    return attributes.request_attributes.update(
+                                        get_output_attributes(output_content)
+                                    )
+                                # Return full parsed json if output isn't found
+                                return attributes.request_attributes.update(
+                                    get_output_attributes(response_content_json)
+                                )
+                            except Exception:
+                                pass
+                            # Fallback to raw response if content was not valid JSON
+                            return attributes.request_attributes.update(
+                                get_output_attributes(raw_response_content)
+                            )
 
                 # Extract from invocation input
                 if "observation" in event_data:
```

> **Review comment:** @luke-moehlenbrock I think this is fine given what …
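For reference, here is a minimal sketch of the payload shape this new branch assumes; the exact Bedrock routing-classifier `rawResponse` format may differ, and the sample text and keys below are invented for illustration:

```python
import json

# Hypothetical routingClassifierTrace modelInvocationOutput: parsedResponse is
# empty and rawResponse.content is a JSON string carrying output.message.content.
model_invocation_output = {
    "parsedResponse": {},
    "rawResponse": {
        "content": json.dumps(
            {"output": {"message": {"content": [{"text": "Route to the billing collaborator"}]}}}
        )
    },
}

# Mirrors the walk done in the diff: load the raw content, then dig out
# output.message.content if it exists.
raw_response_content = model_invocation_output["rawResponse"]["content"]
response_content_json = json.loads(raw_response_content)
output_content = response_content_json.get("output", {}).get("message", {}).get("content")
print(output_content)  # [{'text': 'Route to the billing collaborator'}] -> recorded as span output
```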
```diff
@@ -468,7 +537,13 @@ def _prepare_span_attributes(cls, trace_span_data: Union[TraceSpan, TraceNode])
 
+        # Set name from node type if it's a TraceNode
         if isinstance(trace_span_data, TraceNode):
-            _attributes.name = trace_span_data.node_type
+            if trace_span_data.node_type == "guardrailTrace":
+                pre_or_post = trace_span_data.node_trace_id.split("-")[-1]
+                _attributes.name = pre_or_post + "GuardrailTrace"
+            else:
+                _attributes.name = trace_span_data.node_type
+            cls._process_trace_node(trace_span_data, _attributes)
             return _attributes
 
         # Process each chunk in the trace span
         for trace_data in trace_span_data.chunks:
```
```diff
@@ -502,6 +577,57 @@ def _prepare_span_attributes(cls, trace_span_data: Union[TraceSpan, TraceNode])
                 cls._process_failure_trace(event_data, _attributes)
         return _attributes
 
+    @classmethod
+    def _process_trace_node(cls, trace_node: TraceNode, attributes: _Attributes) -> None:
+        """
+        Process trace node data. This extracts metadata attributes and adds them
+        to the trace node span. This includes routingClassifierTrace, orchestrationTrace,
+        guardrailTrace, etc.
+        """
+        for trace_data in trace_node.chunks:
+            trace_event = AttributeExtractor.get_event_type(trace_data)
+            event_data = trace_data.get(trace_event, {})
+
+            # Extract agent collaborator name for agent-collaborator nodes
+            if trace_node.node_type == "agent-collaborator" and "invocationInput" in event_data:
+                invocation_input = event_data.get("invocationInput", {})
+                if "agentCollaboratorInvocationInput" in invocation_input:
+                    agent_collaborator_name = invocation_input.get(
+                        "agentCollaboratorInvocationInput", {}
+                    ).get("agentCollaboratorName", "")
+                    invocation_type = invocation_input.get("invocationType", "")
+                    if agent_collaborator_name:
+                        attributes.name = f"{invocation_type.lower()}[{agent_collaborator_name}]"
+
+            # Extract child-level metadata first (will be overridden by trace-level metadata)
+            if "modelInvocationOutput" in event_data:
+                model_invocation_output = event_data.get("modelInvocationOutput", {})
+                attributes.metadata.update(
+                    AttributeExtractor.get_metadata_attributes(
+                        model_invocation_output.get("metadata")
+                    )
+                )
+            if observation := event_data.get("observation"):
+                # For agent-collaborator nodes, extract metadata from the observation itself
+                if trace_node.node_type == "agent-collaborator":
+                    if observation_metadata := observation.get("metadata"):
+                        attributes.metadata.update(
+                            AttributeExtractor.get_observation_metadata_attributes(
+                                observation_metadata
+                            )
+                        )
+                # For other nodes, extract from finalResponse if present
+                if final_response := observation.get("finalResponse"):
+                    if final_response_metadata := final_response.get("metadata"):
+                        attributes.metadata.update(
+                            AttributeExtractor.get_metadata_attributes(final_response_metadata)
+                        )
+
+            # Extract trace-level metadata last so it takes precedence
+            # (for orchestrationTrace, guardrailTrace, etc.)
+            if metadata := event_data.get("metadata"):
+                attributes.metadata.update(AttributeExtractor.get_metadata_attributes(metadata))
+
     @classmethod
     def _process_model_invocation_input(
         cls, event_data: Dict[str, Any], attributes: _Attributes
```
```diff
@@ -21,6 +21,11 @@ def generate_unique_trace_id(event_type: str, trace_id: str) -> str:
     Returns:
         A unique trace ID string
     """
+    if "guardrail" in trace_id:
+        # if guardrail, use the first 7 parts of the trace id in order to differentiate
+        # between pre and post guardrail; it will look something like this:
+        # 4ce64021-13b2-23c5-9d70-aaefe8881138-guardrail-pre-0
+        return f"{event_type}_{'-'.join(trace_id.split('-')[:7])}"
     return f"{event_type}_{'-'.join(trace_id.split('-')[:5])}"
```
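To make the pre/post guardrail handling concrete, here is a small runnable sketch; the function body is copied from the hunk above, the `-post-0` trace id is invented by analogy to the `-pre-0` example in the code comment, and it assumes the node's `node_trace_id` is the value this helper returns:

```python
def generate_unique_trace_id(event_type: str, trace_id: str) -> str:
    # Same logic as the hunk above: keep 7 parts so "-guardrail-pre"/"-guardrail-post" survives.
    if "guardrail" in trace_id:
        return f"{event_type}_{'-'.join(trace_id.split('-')[:7])}"
    return f"{event_type}_{'-'.join(trace_id.split('-')[:5])}"


pre_id = generate_unique_trace_id(
    "guardrailTrace", "4ce64021-13b2-23c5-9d70-aaefe8881138-guardrail-pre-0"
)
post_id = generate_unique_trace_id(
    "guardrailTrace", "4ce64021-13b2-23c5-9d70-aaefe8881138-guardrail-post-0"
)
print(pre_id)   # guardrailTrace_4ce64021-13b2-23c5-9d70-aaefe8881138-guardrail-pre
print(post_id)  # guardrailTrace_4ce64021-13b2-23c5-9d70-aaefe8881138-guardrail-post

# The guardrail span-naming change earlier in this PR then takes the last
# dash-separated segment ("pre" or "post") to build the span name:
print(pre_id.split("-")[-1] + "GuardrailTrace")   # preGuardrailTrace
print(post_id.split("-")[-1] + "GuardrailTrace")  # postGuardrailTrace
```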
```diff
@@ -210,6 +215,8 @@ def _handle_chunk_for_current_node(
             trace_data: The trace data to add to the node or span
         """
         if chunk_type not in ["invocationInput", "modelInvocationInput"]:
+            # Add chunk to trace node as well, useful for propogating metadata to the parent node
+            node.add_chunk(trace_data)
             if node.current_span:
                 node.current_span.add_chunk(trace_data)
             else:
```

> **Review comment (bot):** Bug: Trace Data Duplication Causes Processing Errors. Potential duplicate chunk data: the code adds `trace_data` to `node.chunks` in …

> **Reply:** Generally, the first chunk will have `invocationInput` or `modelInvocationInput`. This might not always be the case, but it should be for most `TraceNode`s. In the case where we add a duplicate chunk to a node, it shouldn't have any impact: the function that processes span attributes loops through the chunks in a node and uses each chunk to write attributes to the `_attributes` object, so for duplicate chunks the second chunk would just overwrite the attributes with the exact same values, and there shouldn't be any unexpected behaviour here. We need to add the chunks here to ensure that `routingClassifierTrace` and `orchestrationTrace` properly capture the output from their respective events.
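A tiny sketch of the idempotency argument from the reply above; the attribute keys and the stand-in extractor are hypothetical:

```python
# Hypothetical attribute dict built while looping over a node's chunks.
attributes: dict = {}


def process_chunk(chunk: dict) -> None:
    # Stand-in for the real extractor: copy whatever the chunk carries.
    attributes.update(chunk)


chunk = {"llm.output": "final answer", "metadata.totalTimeMs": 834}
process_chunk(chunk)
process_chunk(chunk)  # duplicate chunk: rewrites the exact same values
print(attributes)     # {'llm.output': 'final answer', 'metadata.totalTimeMs': 834}
```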
```diff
@@ -247,11 +254,13 @@ def _handle_new_trace_node(
             trace_node.chunks.append(trace_data)
         elif event_type == "guardrailTrace":
             trace_node = TraceNode(node_trace_id, event_type)
+            trace_node.add_chunk(trace_data)
             trace_span = TraceSpan(chunk_type)
             trace_span.add_chunk(trace_data)
             trace_node.add_span(trace_span)
         else:
             trace_node = TraceNode(node_trace_id, event_type)
+            trace_node.add_chunk(trace_data)
             trace_span = TraceSpan(chunk_type)
             trace_span.add_chunk(trace_data)
             trace_span.parent_node = parent_node  # This is child for the Agent Span
```
> **Review comment:** Does input data / metadata show up correctly for routing classifier traces?
> **Reply:** Good call on `raw_response_content`; I'll change the name to something a little clearer. As for the `routingClassifierTrace`, it does have the input/output data from the LLM call, but the metadata isn't captured. This is due to how the `TraceNode`s are handled: the function loops through the chunks of the `trace_span_data` and uses the events in the chunks to populate the span attributes, but it also sets the name of the span. If we add the trace data as a chunk to the `TraceNode`, then `_prepare_span_attributes` will see the `modelInvocationInput` or `modelInvocationOutput` attributes in the trace data, which will cause it to reassign the name to `LLM` instead of `routingClassifierTrace` or `orchestrationTrace`. We could probably get around that issue by adding a `_process_trace_node` method that only extracts the metadata from the chunks without reassigning the span name.
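A rough sketch of the naming concern described in this thread, using simplified, hypothetical stand-ins rather than the real instrumentation classes: running a `modelInvocationOutput` chunk through span-style processing reassigns the name to `LLM`, while a metadata-only pass keeps the node named after its trace type.

```python
# Simplified stand-ins for the real attribute object and extractor logic.
class Attrs:
    def __init__(self) -> None:
        self.name = ""
        self.metadata: dict = {}


def process_span_chunk(event_data: dict, attrs: Attrs) -> None:
    # Span-style processing: a model-invocation event also reassigns the span name.
    if "modelInvocationOutput" in event_data:
        attrs.name = "LLM"
        attrs.metadata.update(event_data["modelInvocationOutput"].get("metadata", {}))


def process_trace_node(node_type: str, event_data: dict, attrs: Attrs) -> None:
    # Metadata-only pass, as proposed: keep the node-type name, harvest metadata only.
    attrs.name = node_type
    if "modelInvocationOutput" in event_data:
        attrs.metadata.update(event_data["modelInvocationOutput"].get("metadata", {}))


event = {"modelInvocationOutput": {"metadata": {"usage": {"outputTokens": 42}}}}

span_attrs, node_attrs = Attrs(), Attrs()
process_span_chunk(event, span_attrs)
process_trace_node("routingClassifierTrace", event, node_attrs)
print(span_attrs.name, node_attrs.name)  # LLM routingClassifierTrace
```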