Skip to content

Commit 1c47031

Browse files
tests: add unit tests for helpers (#40)
* tests: add unit tests for helpers * Update tests/test_mcp_server_helpers.py Co-authored-by: Harshal Sheth <hsheth2@gmail.com> * fix chart urn in test --------- Co-authored-by: Harshal Sheth <hsheth2@gmail.com>
1 parent 0beb77a commit 1c47031

File tree

2 files changed

+213
-16
lines changed

2 files changed

+213
-16
lines changed

src/mcp_server_datahub/mcp_server.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _execute_graphql(
9696
)
9797

9898

99-
def _inject_urls_for_urns(
99+
def inject_urls_for_urns(
100100
graph: DataHubGraph, response: Any, json_paths: List[str]
101101
) -> None:
102102
if not _is_datahub_cloud(graph):
@@ -112,7 +112,7 @@ def _inject_urls_for_urns(
112112
item.update(new_item)
113113

114114

115-
def _maybe_convert_to_schema_field_urn(urn: str, column: Optional[str]) -> str:
115+
def maybe_convert_to_schema_field_urn(urn: str, column: Optional[str]) -> str:
116116
if column is not None:
117117
maybe_dataset_urn = Urn.from_string(urn)
118118
if not isinstance(maybe_dataset_urn, DatasetUrn):
@@ -131,7 +131,7 @@ def _maybe_convert_to_schema_field_urn(urn: str, column: Optional[str]) -> str:
131131
queries_gql = (pathlib.Path(__file__).parent / "gql/queries.gql").read_text()
132132

133133

134-
def _clean_gql_response(response: Any) -> Any:
134+
def clean_gql_response(response: Any) -> Any:
135135
if isinstance(response, dict):
136136
banned_keys = {
137137
"__typename",
@@ -141,19 +141,19 @@ def _clean_gql_response(response: Any) -> Any:
141141
for k, v in response.items():
142142
if k in banned_keys or v is None or v == []:
143143
continue
144-
cleaned_v = _clean_gql_response(v)
144+
cleaned_v = clean_gql_response(v)
145145
if cleaned_v is not None and cleaned_v != {}:
146146
cleaned_response[k] = cleaned_v
147147

148148
return cleaned_response
149149
elif isinstance(response, list):
150-
return [_clean_gql_response(item) for item in response]
150+
return [clean_gql_response(item) for item in response]
151151
else:
152152
return response
153153

154154

155-
def _clean_get_entity_response(raw_response: dict) -> dict:
156-
response = _clean_gql_response(raw_response)
155+
def clean_get_entity_response(raw_response: dict) -> dict:
156+
response = clean_gql_response(raw_response)
157157

158158
if response and (schema_metadata := response.get("schemaMetadata")):
159159
# Remove empty platformSchema to reduce response clutter
@@ -191,9 +191,9 @@ def get_entity(urn: str) -> dict:
191191
operation_name="GetEntity",
192192
)["entity"]
193193

194-
_inject_urls_for_urns(client._graph, result, [""])
194+
inject_urls_for_urns(client._graph, result, [""])
195195

196-
return _clean_get_entity_response(result)
196+
return clean_get_entity_response(result)
197197

198198

199199
@mcp.tool(
@@ -274,7 +274,7 @@ def search(
274274
response.pop("searchResults", None)
275275
response.pop("count", None)
276276

277-
return _clean_gql_response(response)
277+
return clean_gql_response(response)
278278

279279

280280
@mcp.tool(
@@ -286,7 +286,7 @@ def get_dataset_queries(
286286
) -> dict:
287287
client = get_datahub_client()
288288

289-
urn = _maybe_convert_to_schema_field_urn(urn, column)
289+
urn = maybe_convert_to_schema_field_urn(urn, column)
290290

291291
entities_filter = FilterDsl.custom_filter(
292292
field="entities", condition="EQUAL", values=[urn]
@@ -310,7 +310,7 @@ def get_dataset_queries(
310310
if query.get("subjects"):
311311
query["subjects"] = _deduplicate_subjects(query["subjects"])
312312

313-
return _clean_gql_response(result)
313+
return clean_gql_response(result)
314314

315315

316316
def _deduplicate_subjects(subjects: list[dict]) -> list[str]:
@@ -374,7 +374,7 @@ def get_lineage(
374374
"searchFlags": {"skipHighlighting": True, "maxAggValues": 3},
375375
}
376376
if asset_lineage_directive.upstream:
377-
result["upstreams"] = _clean_gql_response(
377+
result["upstreams"] = clean_gql_response(
378378
_execute_graphql(
379379
self.graph,
380380
query=entity_details_fragment_gql,
@@ -388,7 +388,7 @@ def get_lineage(
388388
)["searchAcrossLineage"]
389389
)
390390
if asset_lineage_directive.downstream:
391-
result["downstreams"] = _clean_gql_response(
391+
result["downstreams"] = clean_gql_response(
392392
_execute_graphql(
393393
self.graph,
394394
query=entity_details_fragment_gql,
@@ -430,7 +430,7 @@ def get_lineage(
430430

431431
lineage_api = AssetLineageAPI(client._graph)
432432

433-
urn = _maybe_convert_to_schema_field_urn(urn, column)
433+
urn = maybe_convert_to_schema_field_urn(urn, column)
434434
asset_lineage_directive = AssetLineageDirective(
435435
urn=urn,
436436
upstream=upstream,
@@ -439,5 +439,5 @@ def get_lineage(
439439
extra_filters=filters,
440440
)
441441
lineage = lineage_api.get_lineage(asset_lineage_directive)
442-
_inject_urls_for_urns(client._graph, lineage, ["*.searchResults[].entity"])
442+
inject_urls_for_urns(client._graph, lineage, ["*.searchResults[].entity"])
443443
return lineage

tests/test_mcp_server_helpers.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import pytest
2+
from unittest.mock import Mock, patch
3+
from mcp_server_datahub.mcp_server import (
4+
inject_urls_for_urns,
5+
maybe_convert_to_schema_field_urn,
6+
clean_gql_response,
7+
clean_get_entity_response,
8+
)
9+
from datahub.ingestion.graph.links import make_url_for_urn
10+
11+
12+
def test_inject_urls_for_urns():
13+
mock_graph = Mock()
14+
mock_graph.url_for.side_effect = lambda urn: make_url_for_urn(
15+
"https://xyz.com", urn
16+
)
17+
18+
with patch("mcp_server_datahub.mcp_server._is_datahub_cloud", return_value=True):
19+
response = {
20+
"searchResults": [
21+
{
22+
"entity": {
23+
"urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,analytics_db.raw_schema.users,PROD)",
24+
"name": "users",
25+
}
26+
},
27+
{
28+
"entity": {
29+
"urn": "urn:li:chart:(looker,baz)",
30+
"name": "baz",
31+
}
32+
},
33+
]
34+
}
35+
36+
inject_urls_for_urns(mock_graph, response, ["searchResults[].entity"])
37+
38+
expected_response = {
39+
"searchResults": [
40+
{
41+
"entity": {
42+
"urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,analytics_db.raw_schema.users,PROD)",
43+
"url": "https://xyz.com/dataset/urn%3Ali%3Adataset%3A%28urn%3Ali%3AdataPlatform%3Asnowflake%2Canalytics_db.raw_schema.users%2CPROD%29/",
44+
"name": "users",
45+
}
46+
},
47+
{
48+
"entity": {
49+
"urn": "urn:li:chart:(looker,baz)",
50+
"url": "https://xyz.com/chart/urn%3Ali%3Achart%3A%28looker%2Cbaz%29/",
51+
"name": "baz",
52+
}
53+
},
54+
]
55+
}
56+
57+
assert response == expected_response
58+
assert mock_graph.url_for.call_count == 2
59+
60+
61+
def test_maybe_convert_to_schema_field_urn_with_column():
62+
dataset_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,analytics_db.raw_schema.users,PROD)"
63+
column = "user_id"
64+
65+
result = maybe_convert_to_schema_field_urn(dataset_urn, column)
66+
67+
assert (
68+
result
69+
== "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,analytics_db.raw_schema.users,PROD),user_id)"
70+
)
71+
72+
73+
def test_maybe_convert_to_schema_field_urn_without_column():
74+
original_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,analytics_db.raw_schema.users,PROD)"
75+
76+
result = maybe_convert_to_schema_field_urn(original_urn, None)
77+
78+
assert result == original_urn
79+
80+
81+
def test_maybe_convert_to_schema_field_urn_with_incorrect_entity():
82+
chart_urn = "urn:li:chart:(looker,baz)"
83+
84+
# Ok if no column is provided
85+
result = maybe_convert_to_schema_field_urn(chart_urn, None)
86+
assert result == chart_urn
87+
88+
# Fail if column is provided
89+
column = "user_id"
90+
with pytest.raises(ValueError):
91+
maybe_convert_to_schema_field_urn(chart_urn, column)
92+
93+
94+
def test_clean_gql_response_with_dict():
95+
response = {
96+
"__typename": "Dataset",
97+
"urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,analytics_db.raw_schema.users,PROD)",
98+
"name": "users",
99+
"description": None,
100+
"tags": [],
101+
"properties": {"__typename": "Properties", "key1": "value1", "key2": None},
102+
}
103+
104+
result = clean_gql_response(response)
105+
106+
expected_result = {
107+
"urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,analytics_db.raw_schema.users,PROD)",
108+
"name": "users",
109+
"properties": {"key1": "value1"},
110+
}
111+
112+
assert result == expected_result
113+
114+
115+
def test_clean_gql_response_with_nested_empty_objects():
116+
response = {
117+
"urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,analytics_db.raw_schema.users,PROD)",
118+
"name": "users",
119+
"empty_object": {},
120+
"empty_array": [],
121+
"nested_object": {
122+
"empty_object": {},
123+
"empty_array": [],
124+
"valid_data": "value",
125+
"null_value": None,
126+
},
127+
"array_of_objects": [
128+
{"valid": "data", "empty_object": {}, "empty_array": [], "null_value": None}
129+
],
130+
}
131+
132+
result = clean_gql_response(response)
133+
134+
expected_result = {
135+
"urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,analytics_db.raw_schema.users,PROD)",
136+
"name": "users",
137+
"nested_object": {"valid_data": "value"},
138+
"array_of_objects": [{"valid": "data"}],
139+
}
140+
141+
assert result == expected_result
142+
143+
144+
def test_clean_get_entity_response_with_schema_metadata():
145+
raw_response = {
146+
"urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,analytics_db.raw_schema.users,PROD)",
147+
"name": "users",
148+
"schemaMetadata": {
149+
"platformSchema": {
150+
"schema": "" # Empty schema should be removed
151+
},
152+
"fields": [
153+
{
154+
"fieldPath": "user_id",
155+
"recursive": False, # Should be removed
156+
"isPartOfKey": False, # Should be removed
157+
"type": "STRING",
158+
},
159+
{
160+
"fieldPath": "email",
161+
"recursive": True, # Should be kept
162+
"isPartOfKey": True, # Should be kept
163+
"type": "STRING",
164+
},
165+
{
166+
"fieldPath": "created_at",
167+
"recursive": None, # Should be removed
168+
"isPartOfKey": None, # Should be removed
169+
"type": "TIMESTAMP",
170+
},
171+
],
172+
},
173+
}
174+
175+
result = clean_get_entity_response(raw_response)
176+
177+
expected_result = {
178+
"urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,analytics_db.raw_schema.users,PROD)",
179+
"name": "users",
180+
"schemaMetadata": {
181+
"fields": [
182+
{"fieldPath": "user_id", "type": "STRING"},
183+
{
184+
"fieldPath": "email",
185+
"recursive": True,
186+
"isPartOfKey": True,
187+
"type": "STRING",
188+
},
189+
{
190+
"fieldPath": "created_at",
191+
"type": "TIMESTAMP",
192+
},
193+
]
194+
},
195+
}
196+
197+
assert result == expected_result

0 commit comments

Comments
 (0)