Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions recce/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,47 @@ async def list_tools() -> List[Tool]:
},
)
)
# analyze_model is local-only: it relies on the local dbt manifest and
# the adapter-side _full_cll_map cache. The cloud backend (RecceMCPCloudBackend)
# does not implement it, so advertising it in cloud mode would surface a
# ValueError("Unknown tool: analyze_model") when an agent tries to call it.
if self.backend is None:
tools.append(
Tool(
name="analyze_model",
description=(
"Parse a dbt model's compiled SQL into structured evidence (refs, projections, "
"filters, joins, group_by, having, order_by, aggregations, case_expressions, "
"distinct, has_subquery, has_cte, is_set_operation) and return its downstream "
"column impact (which other models and columns depend on it, 1 hop). "
"Single-environment tool — does not require target-base/ or git history. "
"Useful for understanding what a model does structurally and who would be "
"affected by changes to it. Only available with dbt adapter in local mode.\n\n"
"Notes on the structure: refs lists upstream tables/sources only (CTE aliases "
"are excluded). is_set_operation=true means the model is a UNION/INTERSECT/"
"EXCEPT; projections and filters are merged across all legs. downstream covers "
"only direct dependents — for transitive impact, traverse with get_cll.\n\n"
"Performance: the first call per session builds the full column-level lineage "
"map (potentially seconds on large projects); subsequent calls reuse the cached "
"map and are fast.\n\n"
"Returns: {model_id, structure: SqlStructure, downstream: {models, columns}}. "
"If sqlglot cannot parse the compiled SQL, structure.unparseable=true and "
"the agent should fall back to text-level inspection."
),
inputSchema={
"type": "object",
"properties": {
"model_id": {
"type": "string",
"description": (
"Full unique ID of the model to analyze " "(e.g., 'model.project.model_name')."
),
},
},
"required": ["model_id"],
},
)
)
tools.append(
Tool(
name="get_server_info",
Expand Down Expand Up @@ -1313,6 +1354,8 @@ async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
result = await self._tool_get_model(arguments)
elif name == "get_cll":
result = await self._tool_get_cll(arguments)
elif name == "analyze_model":
result = await self._tool_analyze_model(arguments)
elif name == "get_server_info":
result = await self._tool_get_server_info(arguments)
elif name == "select_nodes":
Expand Down Expand Up @@ -2058,6 +2101,44 @@ async def _tool_get_cll(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
)
return cll.model_dump(mode="json")

async def _tool_analyze_model(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Analyze one dbt model: parsed SQL structure plus direct downstream impact.

    Args:
        arguments: Tool arguments; must carry a non-empty ``model_id``.

    Returns:
        A dict with ``model_id``, ``structure`` (the JSON dump of the parsed
        SQL structure) and ``downstream`` (the result of ``collect_downstream``).

    Raises:
        ValueError: if ``model_id`` is missing, the context adapter is not
            dbt, or no compiled SQL can be resolved for the model.
    """
    model_id = arguments.get("model_id")
    if not model_id:
        raise ValueError("model_id is required")
    if self.context.adapter_type != "dbt":
        raise ValueError("analyze_model is only available with dbt adapter")

    # Imported lazily, only after the cheap argument/adapter validation passed.
    from recce.adapter.dbt_adapter import DbtAdapter
    from recce.util.ast_analyze import (
        analyze_sql,
        collect_downstream,
        get_compiled_sql_from_manifest,
    )

    adapter: DbtAdapter = self.context.adapter
    sql_text = get_compiled_sql_from_manifest(adapter.manifest, model_id)
    if sql_text is None:
        raise ValueError(
            f"Cannot resolve compiled SQL for {model_id}. "
            "Run `dbt compile` to populate target/ before calling analyze_model."
        )

    # Same dialect resolution as build_full_cll_map: prefer the adapter type
    # recorded in the manifest metadata, otherwise ask the live adapter.
    # Parsing warehouse-specific SQL with sqlglot's default dialect would
    # frequently come back as unparseable.
    dialect = getattr(adapter.manifest.metadata, "adapter_type", None)
    if not dialect:
        dialect = adapter.adapter.type()
    parsed = analyze_sql(sql_text, dialect=dialect)

    # Structure is parsed before the column-level lineage map is requested.
    impact = collect_downstream(adapter.build_full_cll_map(), model_id)

    response: Dict[str, Any] = {"model_id": model_id}
    response["structure"] = parsed.model_dump(mode="json")
    response["downstream"] = impact
    return response

async def _tool_get_server_info(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Get server context information"""
context = self.context
Expand Down
254 changes: 254 additions & 0 deletions recce/util/ast_analyze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
from typing import Any, Optional

import sqlglot.expressions as exp
from pydantic import BaseModel, Field
from sqlglot import parse_one
from sqlglot.errors import SqlglotError

from recce.models.types import CllData


class JoinInfo(BaseModel):
    """One JOIN clause of a top-level SELECT, as produced by ``_join_info``."""

    # Name of the joined table; when the join target is not a plain table
    # (e.g. a subquery) this is the rendered SQL of that target instead.
    table: str
    # Upper-cased join side when one is present, "CROSS" for cross joins,
    # otherwise "INNER".
    join_type: str
    # Rendered SQL of the ON expression; None when the join has no ON clause.
    condition: Optional[str] = None


class ProjectionInfo(BaseModel):
    """One item of a SELECT list, as produced by ``_projection_from``."""

    # Output name: the alias when present, otherwise the expression's own
    # name; "*" for a bare star.
    name: str
    # Names of every column referenced anywhere inside the item's expression.
    source_columns: list[str] = Field(default_factory=list)
    # True when the expression contains an aggregate function call.
    is_aggregate: bool = False
    # True when the expression is anything other than a bare column reference.
    is_derived: bool = False


class AggregationInfo(BaseModel):
    """One aggregate function call, as produced by ``_aggregation_info``."""

    # Upper-cased sqlglot expression key of the aggregate node.
    function: str
    # Name of the first column found in the aggregate's argument; None when
    # the argument is missing, is a star, or contains no column.
    column: Optional[str] = None


class SqlStructure(BaseModel):
    """Structural summary of one SQL statement, as produced by ``analyze_sql``.

    Clause-level fields are gathered from the top-level SELECT legs only
    (merged across the legs of a set operation); ``refs``, ``aggregations``,
    ``case_expressions``, ``has_subquery`` and ``has_cte`` are gathered from
    the whole parsed tree.
    """

    # Sorted, de-duplicated table names found in the tree, excluding
    # unqualified references that match a CTE alias.
    refs: list[str] = Field(default_factory=list)
    # Select-list items of every top-level SELECT leg.
    projections: list[ProjectionInfo] = Field(default_factory=list)
    # WHERE predicates, split on AND, rendered as SQL.
    filters: list[str] = Field(default_factory=list)
    # JOIN clauses of every top-level SELECT leg.
    joins: list[JoinInfo] = Field(default_factory=list)
    # GROUP BY expressions rendered as SQL.
    group_by: list[str] = Field(default_factory=list)
    # HAVING condition rendered as SQL; one entry per leg that has one.
    having: list[str] = Field(default_factory=list)
    # ORDER BY items rendered as SQL.
    order_by: list[str] = Field(default_factory=list)
    # Every aggregate function call found anywhere in the tree.
    aggregations: list[AggregationInfo] = Field(default_factory=list)
    # Rendered SQL of every CASE expression found anywhere in the tree.
    case_expressions: list[str] = Field(default_factory=list)
    # True when any top-level SELECT leg carries DISTINCT.
    distinct: bool = False
    # True when the tree contains a subquery node.
    has_subquery: bool = False
    # True when the tree contains a WITH clause.
    has_cte: bool = False
    # True when the root of the tree is a set operation.
    is_set_operation: bool = False
    # True when parsing failed; every other field is then left at its default.
    unparseable: bool = False


def _projection_from(select_item: exp.Expression) -> ProjectionInfo:
    """Describe one select-list item as a ``ProjectionInfo``.

    A bare ``*`` yields a projection named "*" with default flags. For any
    other item the alias (if present) is peeled off and the underlying
    expression is inspected for referenced columns and aggregate calls.
    """
    # A node is never both Star and Alias, so the order of these checks is free.
    if isinstance(select_item, exp.Star):
        return ProjectionInfo(name="*")

    if isinstance(select_item, exp.Alias):
        label, expr = select_item.alias, select_item.this
    else:
        label, expr = select_item.alias_or_name, select_item

    return ProjectionInfo(
        name=label,
        source_columns=[column.name for column in expr.find_all(exp.Column)],
        is_aggregate=expr.find(exp.AggFunc) is not None,
        # Anything other than a bare column reference counts as derived.
        is_derived=not isinstance(expr, exp.Column),
    )


def _flatten_and(condition: exp.Expression) -> list[exp.Expression]:
    """Split a condition on AND into its individual predicates, left to right.

    Any node that is not an ``exp.And`` is returned as a single predicate.
    """
    predicates: list[exp.Expression] = []
    pending = [condition]
    while pending:
        node = pending.pop()
        if isinstance(node, exp.And):
            # Push right first so the left operand is popped (and emitted) first.
            pending.append(node.right)
            pending.append(node.left)
        else:
            predicates.append(node)
    return predicates


def _join_info(join: exp.Join) -> JoinInfo:
    """Summarize one JOIN clause as a ``JoinInfo``.

    Args:
        join: The sqlglot join node to describe.

    Returns:
        A ``JoinInfo`` whose ``join_type`` is the upper-cased join side when
        one is present, "CROSS" for cross joins and "INNER" otherwise; whose
        ``table`` is the joined table's name (or the rendered SQL when the
        target is not a plain table); and whose ``condition`` is the rendered
        ON expression, ``"USING (...)"`` for a USING join, or ``None`` when
        the join carries neither.
    """
    side = join.args.get("side")
    kind = join.args.get("kind")
    if side:
        join_type = side.upper()
    elif kind and kind.upper() == "CROSS":
        join_type = "CROSS"
    else:
        join_type = "INNER"

    on = join.args.get("on")
    using = join.args.get("using")
    if on is not None:
        condition = on.sql()
    elif using:
        # `JOIN t USING (a, b)` stores its column list under "using" rather
        # than "on"; without this branch the join key would be dropped.
        condition = "USING (" + ", ".join(col.sql() for col in using) + ")"
    else:
        condition = None

    table = join.this
    table_name = table.name if isinstance(table, exp.Table) else table.sql()

    return JoinInfo(table=table_name, join_type=join_type, condition=condition)


def _aggregation_info(agg: exp.AggFunc) -> AggregationInfo:
    """Describe one aggregate call: its function name and the column it targets.

    The column is ``None`` when the aggregate has no argument, takes ``*``,
    or its argument contains no column reference.
    """
    fn_name = agg.key.upper()
    argument = agg.this

    if argument is None or isinstance(argument, exp.Star):
        return AggregationInfo(function=fn_name, column=None)

    # Use the argument itself when it is a bare column, otherwise the first
    # column found inside it.
    target = argument if isinstance(argument, exp.Column) else argument.find(exp.Column)
    return AggregationInfo(
        function=fn_name,
        column=None if target is None else target.name,
    )


def _top_level_selects(tree: exp.Expression) -> list[exp.Select]:
    """Collect the SELECT legs that make up the outermost statement.

    A plain SELECT yields ``[tree]``. A set operation (any
    ``exp.SetOperation`` subclass, i.e. UNION / INTERSECT / EXCEPT) is
    unfolded through its ``left`` / ``right`` operands, yielding the legs in
    left-to-right order. Any other root yields an empty list. SELECTs nested
    in CTEs or subqueries are deliberately not descended into.
    """
    legs: list[exp.Select] = []
    stack = [tree]
    while stack:
        node = stack.pop()
        if isinstance(node, exp.SetOperation):
            # Right goes on first so the left operand is processed first.
            stack.extend((node.right, node.left))
        elif isinstance(node, exp.Select):
            legs.append(node)
    return legs


def analyze_sql(compiled_sql: str, dialect: Optional[str] = None) -> SqlStructure:
    """Parse compiled SQL with sqlglot and summarize its structure.

    Clause-level fields (projections, filters, joins, group_by, having,
    order_by, distinct) are collected from the top-level SELECT legs only and
    merged across the legs of a set operation. ``refs``, ``aggregations``,
    ``case_expressions``, ``has_subquery`` and ``has_cte`` are collected from
    the whole tree, including CTEs and subqueries.

    Args:
        compiled_sql: The SQL text to analyze.
        dialect: sqlglot dialect name; ``None`` uses sqlglot's default dialect.

    Returns:
        The populated ``SqlStructure``. When the SQL cannot be parsed —
        including when ``dialect`` is a name sqlglot does not recognize — a
        structure with ``unparseable=True`` and all other fields at their
        defaults.
    """
    try:
        tree = parse_one(compiled_sql, dialect=dialect)
    except (SqlglotError, ValueError):
        # sqlglot reports an unrecognized dialect name with a plain ValueError
        # rather than a SqlglotError. Adapter type names are not guaranteed to
        # be sqlglot dialect names, so treat that case as "cannot parse" too
        # instead of letting it escape to the caller.
        return SqlStructure(unparseable=True)

    # Exclude CTE alias names from refs — they are internal aliases, not
    # upstream tables. Without this, every staging-style model with `WITH
    # source AS (...) SELECT * FROM source` leaks the CTE name as a "ref".
    # Only unqualified table references can resolve to a CTE; a qualified
    # name like `raw.orders` still refers to the real table even if a CTE
    # named `orders` exists in the same tree.
    cte_names = {cte.alias_or_name for cte in tree.find_all(exp.CTE)}
    refs = sorted(
        {t.name for t in tree.find_all(exp.Table) if not (t.name in cte_names and not t.db and not t.catalog)}
    )

    projections: list[ProjectionInfo] = []
    filters: list[str] = []
    joins: list[JoinInfo] = []
    group_by: list[str] = []
    having: list[str] = []
    order_by: list[str] = []
    distinct = False

    # Clause-level evidence: one pass per top-level SELECT leg, appended in
    # leg order so a set operation's legs are merged left to right.
    for select in _top_level_selects(tree):
        projections.extend(_projection_from(item) for item in select.expressions)

        where = select.args.get("where")
        if where is not None:
            filters.extend(predicate.sql() for predicate in _flatten_and(where.this))

        joins.extend(_join_info(join) for join in select.args.get("joins") or [])

        group = select.args.get("group")
        if group is not None:
            group_by.extend(item.sql() for item in group.expressions)

        having_clause = select.args.get("having")
        if having_clause is not None:
            # HAVING is kept as a single string per leg (not split on AND).
            having.append(having_clause.this.sql())

        order = select.args.get("order")
        if order is not None:
            order_by.extend(item.sql() for item in order.expressions)

        if select.args.get("distinct") is not None:
            distinct = True

    # Tree-wide evidence: these walk into CTEs and subqueries as well.
    aggregations = [_aggregation_info(agg) for agg in tree.find_all(exp.AggFunc)]
    case_expressions = [c.sql() for c in tree.find_all(exp.Case)]

    return SqlStructure(
        refs=refs,
        projections=projections,
        filters=filters,
        joins=joins,
        group_by=group_by,
        having=having,
        order_by=order_by,
        aggregations=aggregations,
        case_expressions=case_expressions,
        distinct=distinct,
        has_subquery=tree.find(exp.Subquery) is not None,
        has_cte=tree.find(exp.With) is not None,
        is_set_operation=isinstance(tree, exp.SetOperation),
    )


# Resource types whose manifest nodes analyze_model accepts.
_ANALYZABLE_RESOURCE_TYPES = frozenset({"model", "seed", "snapshot"})


def get_compiled_sql_from_manifest(manifest: Any, model_id: str) -> Optional[str]:
    """Look up the compiled SQL of a manifest node.

    Args:
        manifest: Object exposing a ``nodes`` mapping keyed by unique id.
        model_id: Unique id of the node to look up.

    Returns:
        The node's ``compiled_code``, or ``None`` when the node is unknown or
        its compiled code is missing or empty.

    Raises:
        ValueError: if the node declares a ``resource_type`` that is not one
            of the analyzable types.
    """
    node = manifest.nodes.get(model_id)
    if node is None:
        return None

    # Nodes that do not declare a resource type are let through unchecked.
    resource_type = getattr(node, "resource_type", None)
    if resource_type is not None:
        if resource_type not in _ANALYZABLE_RESOURCE_TYPES:
            raise ValueError(
                f"Node {model_id} has resource_type={resource_type!r}; "
                f"analyze_model only supports {sorted(_ANALYZABLE_RESOURCE_TYPES)}."
            )

    sql_text = getattr(node, "compiled_code", None)
    # Normalize any falsy value (missing attribute, None, "") to None.
    if not sql_text:
        return None
    return sql_text


def collect_downstream(cll_data: CllData, model_id: str) -> dict:
    """Find models and columns directly downstream of ``model_id``.

    ``child_map`` is keyed by either node_id or column_id
    (``{node_id}_{column_name}``), and its values are sets of dependent ids.
    Column ownership is resolved via ``CllColumn.table_id`` (populated by
    build_full_cll_map).

    Args:
        cll_data: Lineage data exposing ``nodes``, ``columns`` and ``child_map``.
        model_id: Unique id of the model whose dependents are wanted.

    Returns:
        ``{"models": [...], "columns": [...]}``. ``models`` is a sorted list
        of dependent node ids. ``columns`` is a list of
        ``{"node": ..., "column": ...}`` entries without duplicates, ordered
        by the source model's column order and, within one source column, by
        child column id — so the result is stable across processes. Both
        lists are empty when ``model_id`` is not a known node.
    """
    if model_id not in cll_data.nodes:
        return {"models": [], "columns": []}

    child_map = cll_data.child_map

    # Node-level downstream: direct children that are known nodes, ignoring
    # any self-reference.
    models: set[str] = {
        child_id
        for child_id in child_map.get(model_id, ())
        if child_id != model_id and child_id in cll_data.nodes
    }

    # Column-level downstream: walk each column of the source model.
    columns: list[dict] = []
    seen_cols: set[tuple] = set()
    for col_name in cll_data.nodes[model_id].columns:
        # child_map values are sets, whose iteration order for strings varies
        # between interpreter runs; sorting keeps the emitted order stable.
        for child_col_id in sorted(child_map.get(f"{model_id}_{col_name}", ())):
            child_col = cll_data.columns.get(child_col_id)
            # Skip dangling ids and columns that belong to the source model.
            if child_col is None or child_col.table_id == model_id:
                continue
            entry = (child_col.table_id, child_col.name)
            if entry in seen_cols:
                continue
            seen_cols.add(entry)
            columns.append({"node": child_col.table_id, "column": child_col.name})
            if child_col.table_id:
                models.add(child_col.table_id)

    return {"models": sorted(models), "columns": columns}
1 change: 1 addition & 0 deletions tests/test_mcp_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,7 @@ async def test_list_tools_returns_all_server_mode_tools(self, mcp_e2e_with_data)
"impact_analysis",
"get_model",
"get_cll",
"analyze_model",
"get_server_info",
"set_backend",
"select_nodes",
Expand Down
Loading
Loading