Skip to content

Commit ed5a51c

Browse files
authored
Add endpoint to list available dialects (#1393)
* Add endpoint to list available dialects * Add support for GraphQL querying of dialects
1 parent 99d2b1e commit ed5a51c

File tree

8 files changed

+142
-5
lines changed

8 files changed

+142
-5
lines changed

datajunction-server/datajunction_server/api/engines.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,27 @@
1313
from datajunction_server.internal.access.authentication.http import SecureAPIRouter
1414
from datajunction_server.internal.engines import get_engine
1515
from datajunction_server.models.engine import EngineInfo
16+
from datajunction_server.models.dialect import DialectRegistry, DialectInfo
1617
from datajunction_server.utils import get_session, get_settings
1718

1819
settings = get_settings()
1920
router = SecureAPIRouter(tags=["engines"])
2021

2122

23+
@router.get("/dialects", response_model=list[DialectInfo])
24+
async def list_dialects():
25+
"""
26+
Returns a list of registered SQL dialects and their associated transpilation plugin class names.
27+
"""
28+
return [
29+
DialectInfo(
30+
name=dialect,
31+
plugin_class=plugin.__name__,
32+
)
33+
for dialect, plugin in DialectRegistry._registry.items()
34+
]
35+
36+
2237
@router.get("/engines/", response_model=List[EngineInfo])
2338
async def list_engines(
2439
*,

datajunction-server/datajunction_server/api/graphql/main.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,21 @@
77
from fastapi import Depends
88
from strawberry.fastapi import GraphQLRouter
99
from strawberry.types import Info
10-
1110
from datajunction_server.api.graphql.queries.catalogs import list_catalogs
1211
from datajunction_server.api.graphql.queries.dag import common_dimensions
13-
from datajunction_server.api.graphql.queries.engines import list_engines
12+
from datajunction_server.api.graphql.queries.engines import list_engines, list_dialects
1413
from datajunction_server.api.graphql.queries.nodes import (
1514
find_nodes,
1615
find_nodes_paginated,
1716
)
1817
from datajunction_server.api.graphql.queries.sql import measures_sql
1918
from datajunction_server.api.graphql.queries.tags import list_tag_types, list_tags
2019
from datajunction_server.api.graphql.scalars import Connection
21-
from datajunction_server.api.graphql.scalars.catalog_engine import Catalog, Engine
20+
from datajunction_server.api.graphql.scalars.catalog_engine import (
21+
Catalog,
22+
Engine,
23+
DialectInfo,
24+
)
2225
from datajunction_server.api.graphql.scalars.node import DimensionAttribute, Node
2326
from datajunction_server.api.graphql.scalars.sql import GeneratedSQL
2427
from datajunction_server.api.graphql.scalars.tag import Tag
@@ -81,9 +84,15 @@ class Query:
8184
# Catalog and engine queries
8285
list_catalogs: list[Catalog] = strawberry.field(
8386
resolver=log_resolver(list_catalogs),
87+
description="List available catalogs",
8488
)
8589
list_engines: list[Engine] = strawberry.field(
8690
resolver=log_resolver(list_engines),
91+
description="List all available engines",
92+
)
93+
list_dialects: list[DialectInfo] = strawberry.field(
94+
resolver=log_resolver(list_dialects),
95+
description="List all supported SQL dialects",
8796
)
8897

8998
# Node search queries
@@ -105,6 +114,7 @@ class Query:
105114
# Generate SQL queries
106115
measures_sql: list[GeneratedSQL] = strawberry.field(
107116
resolver=log_resolver(measures_sql),
117+
description="Get measures SQL for a list of metrics, dimensions, and filters.",
108118
)
109119

110120
# Tags queries

datajunction-server/datajunction_server/api/graphql/queries/engines.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from sqlalchemy import select
88
from strawberry.types import Info
99

10-
from datajunction_server.api.graphql.scalars.catalog_engine import Engine
10+
from datajunction_server.models.dialect import DialectRegistry
11+
from datajunction_server.api.graphql.scalars.catalog_engine import Engine, DialectInfo
1112
from datajunction_server.database.engine import Engine as DBEngine
1213

1314

@@ -23,3 +24,19 @@ async def list_engines(
2324
Engine.from_pydantic(engine) # type: ignore #pylint: disable=E1101
2425
for engine in (await session.execute(select(DBEngine))).scalars().all()
2526
]
27+
28+
29+
async def list_dialects(
30+
*,
31+
info: Info = None,
32+
) -> List[DialectInfo]:
33+
"""
34+
List all supported dialects
35+
"""
36+
return [
37+
DialectInfo( # type: ignore
38+
name=dialect,
39+
plugin_class=plugin.__name__,
40+
)
41+
for dialect, plugin in DialectRegistry._registry.items()
42+
]

datajunction-server/datajunction_server/api/graphql/scalars/catalog_engine.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from datajunction_server.models.catalog import CatalogInfo
66
from datajunction_server.models.engine import EngineInfo
7-
from datajunction_server.models.node import Dialect as Dialect_
7+
from datajunction_server.models.dialect import Dialect as Dialect_
8+
from datajunction_server.models.dialect import DialectInfo as DialectInfo_
89

910
Dialect = strawberry.enum(Dialect_)
1011

@@ -21,3 +22,10 @@ class Catalog:
2122
"""
2223
Class for a Catalog
2324
"""
25+
26+
27+
@strawberry.experimental.pydantic.type(model=DialectInfo_, all_fields=True)
28+
class DialectInfo:
29+
"""
30+
Class for DialectInfo
31+
"""

datajunction-server/datajunction_server/api/graphql/schema.graphql

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ enum Dialect {
8585
DRUID
8686
}
8787

88+
type DialectInfo {
89+
name: String!
90+
pluginClass: String!
91+
}
92+
8893
type DimensionAttribute {
8994
name: String!
9095
attribute: String
@@ -331,9 +336,15 @@ enum PartitionType {
331336
}
332337

333338
type Query {
339+
"""List available catalogs"""
334340
listCatalogs: [Catalog!]!
341+
342+
"""List all available engines"""
335343
listEngines: [Engine!]!
336344

345+
"""List all supported SQL dialects"""
346+
listDialects: [DialectInfo!]!
347+
337348
"""Find nodes based on the search parameters."""
338349
findNodes(
339350
"""A fragment of a node name to search for"""
@@ -383,6 +394,8 @@ type Query {
383394
"""A list of nodes to find common dimensions for"""
384395
nodes: [String!] = null
385396
): [DimensionAttribute!]!
397+
398+
"""Get measures SQL for a list of metrics, dimensions, and filters."""
386399
measuresSql(
387400
cube: CubeDefinition!
388401
engine: EngineSettings = null

datajunction-server/datajunction_server/models/dialect.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import logging
2+
3+
from pydantic import BaseModel
24
from datajunction_server.enum import StrEnum
35
from datajunction_server.utils import get_settings
46
from typing import TYPE_CHECKING
@@ -42,6 +44,15 @@ def _missing_(cls, value: object) -> "Dialect":
4244
raise TypeError(f"{value!r} is not a valid string for {cls.__name__}")
4345

4446

47+
class DialectInfo(BaseModel):
48+
"""
49+
Information about a SQL dialect and its associated plugin class.
50+
"""
51+
52+
name: str
53+
plugin_class: str
54+
55+
4556
class DialectRegistry:
4657
"""
4758
Registry for SQL dialect plugins.

datajunction-server/tests/api/engine_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,28 @@ async def test_engine_raise_on_engine_already_exists(
151151
assert response.status_code == 409
152152
data = response.json()
153153
assert data == {"detail": "Engine already exists: `spark-three` version `3.3.1`"}
154+
155+
156+
@pytest.mark.asyncio
157+
async def test_dialects_list(
158+
module__client: AsyncClient,
159+
) -> None:
160+
"""
161+
Test listing dialects
162+
"""
163+
response = await module__client.get("/dialects/")
164+
assert response.status_code == 200
165+
assert response.json() == [
166+
{
167+
"name": "spark",
168+
"plugin_class": "SQLTranspilationPlugin",
169+
},
170+
{
171+
"name": "druid",
172+
"plugin_class": "SQLTranspilationPlugin",
173+
},
174+
{
175+
"name": "trino",
176+
"plugin_class": "SQLTranspilationPlugin",
177+
},
178+
]

datajunction-server/tests/api/graphql/engine_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,41 @@ async def test_engine_list(
6262
],
6363
},
6464
}
65+
66+
67+
@pytest.mark.asyncio
68+
async def test_list_dialects(
69+
client: AsyncClient,
70+
) -> None:
71+
"""
72+
Test listing dialects
73+
"""
74+
query = """
75+
{
76+
listDialects{
77+
name
78+
pluginClass
79+
}
80+
}
81+
"""
82+
response = await client.post("/graphql", json={"query": query})
83+
assert response.status_code == 200
84+
data = response.json()
85+
assert data == {
86+
"data": {
87+
"listDialects": [
88+
{
89+
"name": "spark",
90+
"pluginClass": "SQLTranspilationPlugin",
91+
},
92+
{
93+
"name": "druid",
94+
"pluginClass": "SQLTranspilationPlugin",
95+
},
96+
{
97+
"name": "trino",
98+
"pluginClass": "SQLTranspilationPlugin",
99+
},
100+
],
101+
},
102+
}

0 commit comments

Comments
 (0)