Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
.PHONY: setup lint test clean format format-check

PY_FILES = src tests scripts

# Setup development environment
setup:
uv sync

# Format code with ruff
format:
uv run ruff format src tests
uv run ruff format $(PY_FILES)

# Check code formatting with ruff
format-check:
uv run ruff format --check src tests
uv run ruff format --check $(PY_FILES)

# Lint with ruff and mypy
lint:
uv run ruff check src tests
uv run mypy src tests
uv run ruff check $(PY_FILES)
uv run mypy $(PY_FILES)

# Run tests
test:
Expand Down
51 changes: 51 additions & 0 deletions scripts/test_main_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import json
from typing import Optional

import click
from datahub.sdk.main_client import DataHubClient

from mcp_server_datahub.mcp_server import (
get_dataset_queries,
get_entity,
get_lineage,
search,
set_client,
)


def _divider() -> None:
print("\n" + "-" * 80 + "\n")


@click.command()
@click.argument("urn_or_query", required=False)
def main(urn_or_query: Optional[str]) -> None:
    """Exercise the main MCP server tools and print their JSON output.

    URN_OR_QUERY is either an entity URN (any value starting with "urn:"),
    which is used as-is, or a search query whose first result is used.
    When omitted, the "*" query is used.
    """
    set_client(DataHubClient.from_env())

    if urn_or_query is None:
        urn_or_query = "*"
        print("No query provided, will use '*' query")

    urn: str
    if urn_or_query.startswith("urn:"):
        urn = urn_or_query
    else:
        search_data = search(query=urn_or_query)
        results = search_data["searchResults"]
        if not results:
            # Fail with a readable CLI error instead of an IndexError traceback.
            raise click.ClickException(
                f"No search results for query {urn_or_query!r}"
            )
        for entity in results:
            print(entity["entity"]["urn"])
        # Continue with the first hit only.
        urn = results[0]["entity"]["urn"]

    _divider()
    print("Getting entity:", urn)
    print(json.dumps(get_entity(urn), indent=2))
    _divider()
    print("Getting lineage:", urn)
    print(json.dumps(get_lineage(urn, upstream=False, max_hops=3), indent=2))
    _divider()
    print("Getting queries", urn)
    print(json.dumps(get_dataset_queries(urn), indent=2))


# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
28 changes: 18 additions & 10 deletions src/mcp_server_datahub/__main__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
import importlib.metadata

import click
from datahub.ingestion.graph.client import get_default_graph
from datahub.ingestion.graph.config import ClientMode
from datahub.sdk.main_client import DataHubClient
from typing_extensions import Literal

from mcp_server_datahub.mcp_server import mcp, set_client

# Because we want to override the datahub_component, we can't use DataHubClient.from_env()
# and need to use the DataHubClient constructor directly.
mcp_version = importlib.metadata.version("mcp-server-datahub")
graph = get_default_graph(
client_mode=ClientMode.SDK,
datahub_component=f"mcp-server-datahub/{mcp_version}",
)
set_client(DataHubClient(graph=graph))

@click.command()
@click.option(
"--transport",
type=click.Choice(["stdio", "sse", "streamable-http"]),
default="stdio",
)
def main(transport: Literal["stdio", "sse", "streamable-http"]) -> None:
# Because we want to override the datahub_component, we can't use DataHubClient.from_env()
# and need to use the DataHubClient constructor directly.
mcp_version = importlib.metadata.version("mcp-server-datahub")
graph = get_default_graph(
client_mode=ClientMode.SDK,
datahub_component=f"mcp-server-datahub/{mcp_version}",
)
set_client(DataHubClient(graph=graph))

def main() -> None:
mcp.run()
mcp.run(transport=transport)


if __name__ == "__main__":
Expand Down
38 changes: 0 additions & 38 deletions src/mcp_server_datahub/mcp_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import contextlib
import contextvars
import json
import pathlib
from typing import Any, Dict, Iterator, List, Optional

Expand Down Expand Up @@ -311,40 +310,3 @@ def get_lineage(urn: str, upstream: bool, max_hops: int = 1) -> dict:
lineage = lineage_api.get_lineage(asset_lineage_directive)
_inject_urls_for_urns(client._graph, lineage, ["*.searchResults[].entity"])
return lineage


if __name__ == "__main__":
    # Ad-hoc smoke test: resolve a URN (given directly or via search), then
    # dump the entity, its lineage, and its queries as indented JSON.
    import sys

    set_client(DataHubClient.from_env())

    if len(sys.argv) > 1:
        urn_or_query = sys.argv[1]
    else:
        urn_or_query = "*"
        print("No query provided, will use '*' query")

    urn: str
    if urn_or_query.startswith("urn:"):
        urn = urn_or_query
    else:
        # Forward the user's query so the search results actually reflect it.
        search_data = search(query=urn_or_query)
        results = search_data["searchResults"]
        if not results:
            # Exit with a readable message instead of an IndexError traceback.
            sys.exit(f"No search results for query {urn_or_query!r}")
        for entity in results:
            print(entity["entity"]["urn"])
        # Continue with the first hit only.
        urn = results[0]["entity"]["urn"]

    def _divider() -> None:
        """Print an 80-character dashed rule with a blank line above and below."""
        print("\n" + "-" * 80 + "\n")

    _divider()
    print("Getting entity:", urn)
    print(json.dumps(get_entity(urn), indent=2))
    _divider()
    print("Getting lineage:", urn)
    print(json.dumps(get_lineage(urn, upstream=False, max_hops=3), indent=2))
    _divider()
    print("Getting queries", urn)
    print(json.dumps(get_dataset_queries(urn), indent=2))