11import contextlib
22import contextvars
3+ from functools import lru_cache
34import json
45import pathlib
56from typing import Any , Dict , Iterator , List , Optional
@@ -106,7 +107,54 @@ def _clean_gql_response(response: Any) -> Any:
106107 return response
107108
108109
109- @mcp .tool (description = "Get an entity by its DataHub URN." )
class SemanticVersionStruct(BaseModel):
    """One schema version entry: a semantic version paired with its version stamp."""

    # Value read from the "semanticVersion" key of the raw payload.
    semantic_version: str
    # Value read from the "versionStamp" key of the raw payload.
    version_stamp: str

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SemanticVersionStruct":
        """Build an instance from a raw camelCase mapping.

        Args:
            data: Mapping that must contain "semanticVersion" and
                "versionStamp".

        Raises:
            KeyError: If either key is missing from ``data``.
        """
        semantic_version = data["semanticVersion"]
        version_stamp = data["versionStamp"]
        return cls(semantic_version=semantic_version, version_stamp=version_stamp)
120+
121+
class SchemaVersionList(BaseModel):
    """Schema versions known for a dataset, plus the one reported as latest."""

    # Entry built from the "latestVersion" field of the GraphQL result.
    latest_version: SemanticVersionStruct
    # Entries built from the "semanticVersionList" field, in response order.
    versions: List[SemanticVersionStruct]
125+
126+
def _get_schema_version_list(
    datahub_client: DataHubClient, dataset_urn: str
) -> SchemaVersionList | None:
    """Fetch the schema version list for a dataset.

    Runs the ``getSchemaVersionList`` operation against the client's graph.

    Args:
        datahub_client: Client whose ``_graph`` is used to execute the query.
        dataset_urn: URN of the dataset, sent as ``input.datasetUrn``.

    Returns:
        A ``SchemaVersionList``, or ``None`` when the response carries no
        ``getSchemaVersionList`` payload or that payload has no
        ``latestVersion`` entry.
    """
    resp = _execute_graphql(
        datahub_client._graph,
        query=entity_details_fragment_gql,
        variables={"input": {"datasetUrn": dataset_urn}},
        operation_name="getSchemaVersionList",
    )
    raw_schema_versions = resp.get("getSchemaVersionList")
    if not raw_schema_versions:
        return None

    # A missing or null `latestVersion` cannot be turned into a
    # SemanticVersionStruct (from_dict indexes the mapping directly), so it is
    # reported as "no version information" rather than raising.
    raw_latest = raw_schema_versions.get("latestVersion")
    if not raw_latest:
        return None

    # `or []` also covers a key that is present but null.
    raw_versions = raw_schema_versions.get("semanticVersionList") or []

    return SchemaVersionList(
        latest_version=SemanticVersionStruct.from_dict(raw_latest),
        versions=[
            SemanticVersionStruct.from_dict(raw_version)
            for raw_version in raw_versions
        ],
    )
153+
154+
155+ @mcp .tool (
156+ description = "Get an entity by its DataHub URN. This also provide schema_version_list(latest version, all versions) if available."
157+ )
110158def get_entity (urn : str ) -> dict :
111159 client = get_client ()
112160
@@ -125,6 +173,12 @@ def get_entity(urn: str) -> dict:
125173
126174 _inject_urls_for_urns (client ._graph , result , ["" ])
127175
176+ if schema_version_list := _get_schema_version_list (client , urn ):
177+ result ["schemaVersionList" ] = {
178+ "latestVersion" : schema_version_list .latest_version .semantic_version ,
179+ "versions" : sorted ([v .semantic_version for v in schema_version_list .versions ]),
180+ }
181+
128182 return _clean_gql_response (result )
129183
130184
@@ -313,6 +367,34 @@ def get_lineage(urn: str, upstream: bool, max_hops: int = 1) -> dict:
313367 return lineage
314368
315369
@mcp.tool(description="Get schema from a dataset by its URN and version.")
def get_versioned_dataset(dataset_urn: str, semantic_version: str) -> dict[str, Any]:
    """Return a dataset as of one of its semantic schema versions.

    The semantic version is first translated to a version stamp via the
    dataset's schema version list; that stamp is then passed to the
    ``getVersionedDataset`` operation.

    Not memoized on purpose: the client is resolved inside the call via
    ``get_client()``, so a cache keyed only on the two arguments could replay
    one client's result for another, and it would hand the same mutable dict
    to every caller.

    Args:
        dataset_urn: URN of the dataset to look up.
        semantic_version: Semantic version string to resolve.

    Returns:
        The ``versionedDataset`` payload of the response, or an empty dict
        when that field is absent or null.

    Raises:
        ValueError: If the dataset has no schema version list, or if
            ``semantic_version`` is not among its versions.
    """
    client = get_client()

    schema_version_list = _get_schema_version_list(client, dataset_urn)
    if not schema_version_list:
        raise ValueError(f"No schema_version_list found for dataset {dataset_urn}")

    version_stamp_mapping = {
        struct.semantic_version: struct.version_stamp
        for struct in schema_version_list.versions
    }

    target_version_stamp = version_stamp_mapping.get(semantic_version)
    if not target_version_stamp:
        raise ValueError(
            f"Version '{semantic_version}' not found for dataset '{dataset_urn}'"
        )

    resp = _execute_graphql(
        client._graph,
        query=entity_details_fragment_gql,
        variables={"urn": dataset_urn, "versionStamp": target_version_stamp},
        operation_name="getVersionedDataset",
    )
    # `or {}` keeps the declared dict return even when the key is present but null.
    return resp.get("versionedDataset") or {}
396+
397+
316398if __name__ == "__main__" :
317399 import sys
318400
@@ -348,3 +430,6 @@ def _divider() -> None:
348430 _divider ()
349431 print ("Getting queries" , urn )
350432 print (json .dumps (get_dataset_queries (urn ), indent = 2 ))
_divider()
# The keyword must match the tool's parameter name `semantic_version`; the
# previous spelling `sementic_version` does not bind to any parameter.
print(json.dumps(get_versioned_dataset(urn, semantic_version="0.0.0"), indent=2))
_divider()
0 commit comments