Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions astro-airflow-mcp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@ af config variables
af config pools

# Direct API access (any endpoint)
af api --endpoints # List all available endpoints
af api --endpoints --filter variable # Filter endpoints by pattern
af api ls # List all available endpoints
af api ls --filter variable # Filter endpoints by pattern
af api dags # GET /api/v{1,2}/dags
af api dags -F limit=10 # With query parameters
af api variables -X POST -F key=x -f value=y # Create variable
Expand Down Expand Up @@ -323,8 +323,8 @@ The `af api` command provides direct access to any Airflow REST API endpoint, si

```bash
# Discover available endpoints
af api --endpoints
af api --endpoints --filter variable
af api ls
af api ls --filter variable

# GET requests (default)
af api dags
Expand All @@ -346,7 +346,7 @@ af api dags -i
af api health --raw

# Get full OpenAPI spec
af api --spec
af api spec
```

**Field syntax:**
Expand Down
16 changes: 16 additions & 0 deletions astro-airflow-mcp/src/astro_airflow_mcp/adapters/airflow_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import Any

import yaml

from astro_airflow_mcp.adapters.base import AirflowAdapter, NotFoundError


Expand Down Expand Up @@ -322,6 +324,20 @@ def get_config(self) -> dict[str, Any]:
"note": "Config endpoint may require expose_config=True in airflow.cfg",
}

def get_openapi_spec(self) -> dict[str, Any]:
"""Get the OpenAPI specification for the Airflow 2.x API.

Airflow 2.x serves the spec as YAML at /api/v1/openapi.yaml.
"""
result = self.raw_request("GET", "openapi.yaml", raw_endpoint=False)
if result["status_code"] >= 400:
raise Exception(f"HTTP {result['status_code']}: {result.get('body', 'Unknown error')}")
# Parse YAML
body = result["body"]
if isinstance(body, str):
return yaml.safe_load(body)
return body

def clear_task_instances(
self,
dag_id: str,
Expand Down
10 changes: 10 additions & 0 deletions astro-airflow-mcp/src/astro_airflow_mcp/adapters/airflow_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,16 @@ def get_config(self) -> dict[str, Any]:
"""Get Airflow configuration."""
return self._call("config")

def get_openapi_spec(self) -> dict[str, Any]:
"""Get the OpenAPI specification for the Airflow 3.x API.

Airflow 3.x serves the spec as JSON at /openapi.json (no version prefix).
"""
result = self.raw_request("GET", "openapi.json", raw_endpoint=True)
if result["status_code"] >= 400:
raise Exception(f"HTTP {result['status_code']}: {result.get('body', 'Unknown error')}")
return result["body"]

# Airflow 3.x specific features

def get_task_instances(
Expand Down
8 changes: 8 additions & 0 deletions astro-airflow-mcp/src/astro_airflow_mcp/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,14 @@ def get_version(self) -> dict[str, Any]:
def get_config(self) -> dict[str, Any]:
"""Get Airflow configuration."""

@abstractmethod
def get_openapi_spec(self) -> dict[str, Any]:
"""Get the OpenAPI specification for the Airflow API.

Returns:
Parsed OpenAPI spec as a dict with 'openapi', 'paths', etc.
"""

# Task Instance Operations
@abstractmethod
def clear_task_instances(
Expand Down
139 changes: 62 additions & 77 deletions astro-airflow-mcp/src/astro_airflow_mcp/cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Annotated, Any

import typer
import yaml

from astro_airflow_mcp.cli.context import get_adapter
from astro_airflow_mcp.cli.output import output_error, output_json
Expand Down Expand Up @@ -96,11 +95,61 @@ def format_output(
output_json(result["body"])


def _api_ls(filter_pattern: str | None = None) -> None:
"""List available API endpoints."""
try:
adapter = get_adapter()
spec_data = adapter.get_openapi_spec()
if isinstance(spec_data, dict) and "paths" in spec_data:
paths = sorted(spec_data["paths"].keys())
if filter_pattern:
paths = [p for p in paths if filter_pattern.lower() in p.lower()]
output_json({"endpoints": paths, "count": len(paths)})
else:
output_error("Could not parse OpenAPI spec")
raise typer.Exit(1)
except typer.Exit:
raise
except Exception as e:
output_error(str(e))
raise typer.Exit(1) from None


def _api_spec(include_headers: bool = False) -> None:
"""Fetch the full OpenAPI specification."""
try:
adapter = get_adapter()
spec_data = adapter.get_openapi_spec()
if include_headers:
# Wrap in a response-like structure for consistency
output_json(
{
"status_code": 200,
"headers": {},
"body": spec_data,
}
)
else:
output_json(spec_data)
except typer.Exit:
raise
except Exception as e:
output_error(str(e))
raise typer.Exit(1) from None


def api_command(
endpoint: Annotated[
str | None,
typer.Argument(
help="API endpoint path (e.g., 'dags', '/dags/my_dag'). Leading slash optional."
help="API endpoint path (e.g., 'dags', '/dags/my_dag'), or 'ls' to list endpoints, or 'spec' to get OpenAPI spec."
),
] = None,
filter_pattern: Annotated[
str | None,
typer.Option(
"--filter",
help="Filter endpoints containing this string (use with 'ls')",
),
] = None,
method: Annotated[
Expand Down Expand Up @@ -157,27 +206,6 @@ def api_command(
help="Use endpoint path as-is without API version prefix",
),
] = False,
spec: Annotated[
bool,
typer.Option(
"--spec",
help="Fetch full OpenAPI spec (JSON)",
),
] = False,
endpoints: Annotated[
bool,
typer.Option(
"--endpoints",
help="List available API endpoints",
),
] = False,
filter_pattern: Annotated[
str | None,
typer.Option(
"--filter",
help="Filter endpoints containing this string (use with --endpoints)",
),
] = None,
) -> None:
"""Make direct requests to any Airflow REST API endpoint.

Expand All @@ -188,10 +216,10 @@ def api_command(
Examples:

# List all available endpoints
af api --endpoints
af api ls

# Filter endpoints by pattern
af api --endpoints --filter variable
af api ls --filter variable

# List DAGs
af api dags
Expand All @@ -215,63 +243,20 @@ def api_command(
af api health --raw

# Fetch full OpenAPI spec
af api --spec
af api spec
"""
# Handle --spec flag
if spec:
try:
adapter = get_adapter()
# Try AF3 location first (/openapi.json), then AF2 (/api/v1/openapi.yaml)
result = adapter.raw_request("GET", "openapi.json", raw_endpoint=True)
is_yaml = False
if result["status_code"] == 404:
# Fall back to AF2 location (YAML format)
result = adapter.raw_request("GET", "openapi.yaml", raw_endpoint=False)
is_yaml = True
if result["status_code"] >= 400:
output_error(f"HTTP {result['status_code']}: {result.get('body', 'Unknown error')}")
# Parse YAML if needed
if is_yaml and isinstance(result["body"], str):
result["body"] = yaml.safe_load(result["body"])
format_output(result, include_headers=include)
except Exception as e:
output_error(str(e))
# Handle subcommand dispatch for reserved names
if endpoint == "ls":
_api_ls(filter_pattern=filter_pattern)
return

# Handle --endpoints flag
if endpoints:
try:
adapter = get_adapter()
# Try AF3 location first (/openapi.json), then AF2 (/api/v1/openapi.yaml)
result = adapter.raw_request("GET", "openapi.json", raw_endpoint=True)
is_yaml = False
if result["status_code"] == 404:
# Fall back to AF2 location (YAML format)
result = adapter.raw_request("GET", "openapi.yaml", raw_endpoint=False)
is_yaml = True
if result["status_code"] >= 400:
output_error(f"HTTP {result['status_code']}: {result.get('body', 'Unknown error')}")
# Parse YAML if needed, extract endpoint paths
spec_data = result["body"]
if is_yaml and isinstance(spec_data, str):
spec_data = yaml.safe_load(spec_data)
if isinstance(spec_data, dict) and "paths" in spec_data:
paths = sorted(spec_data["paths"].keys())
if filter_pattern:
paths = [p for p in paths if filter_pattern.lower() in p.lower()]
output_json({"endpoints": paths, "count": len(paths)})
else:
output_error("Could not parse OpenAPI spec")
except Exception as e:
output_error(str(e))
if endpoint == "spec":
_api_spec(include_headers=include)
return

# Endpoint is required if not fetching spec or endpoints
# Endpoint is required for direct API calls
if endpoint is None:
output_error(
"Endpoint is required. Use 'af api <endpoint>', 'af api --endpoints', or 'af api --spec'"
)
return
output_error("Endpoint is required. Use 'af api <endpoint>', 'af api ls', or 'af api spec'")
raise typer.Exit(1)

# Validate method
valid_methods = {"GET", "POST", "PATCH", "PUT", "DELETE"}
Expand Down
22 changes: 11 additions & 11 deletions astro-airflow-mcp/tests/integration/test_cli_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,11 @@ def test_af_api_get_dag(self, cli_env):
assert output["dag_id"] == dag_id
print(f"af api got DAG: {dag_id}")

def test_af_api_endpoints_flag(self, cli_env):
"""Should list endpoints via --endpoints flag."""
def test_af_api_ls_subcommand(self, cli_env):
"""Should list endpoints via ls subcommand."""
result = runner.invoke(
app,
["api", "--endpoints"],
["api", "ls"],
env=cli_env,
)

Expand All @@ -173,13 +173,13 @@ def test_af_api_endpoints_flag(self, cli_env):
assert "endpoints" in output
assert "count" in output
assert output["count"] > 0
print(f"af api --endpoints returned {output['count']} endpoints")
print(f"af api ls returned {output['count']} endpoints")

def test_af_api_endpoints_with_filter(self, cli_env):
"""Should filter endpoints via --endpoints --filter."""
def test_af_api_ls_with_filter(self, cli_env):
"""Should filter endpoints via ls --filter."""
result = runner.invoke(
app,
["api", "--endpoints", "--filter", "dag"],
["api", "ls", "--filter", "dag"],
env=cli_env,
)

Expand All @@ -191,11 +191,11 @@ def test_af_api_endpoints_with_filter(self, cli_env):
assert "dag" in endpoint.lower()
print(f"Filtered to {output['count']} dag-related endpoints")

def test_af_api_spec_flag(self, cli_env):
"""Should fetch OpenAPI spec via --spec flag."""
def test_af_api_spec_subcommand(self, cli_env):
"""Should fetch OpenAPI spec via spec subcommand."""
result = runner.invoke(
app,
["api", "--spec"],
["api", "spec"],
env=cli_env,
)

Expand All @@ -204,7 +204,7 @@ def test_af_api_spec_flag(self, cli_env):
# Should have OpenAPI structure
assert "openapi" in output or "swagger" in output
assert "paths" in output
print(f"af api --spec returned OpenAPI spec with {len(output['paths'])} paths")
print(f"af api spec returned OpenAPI spec with {len(output['paths'])} paths")

def test_af_api_include_headers(self, cli_env):
"""Should include headers with -i flag."""
Expand Down
Loading