Skip to content

Commit c9551ba

Browse files
committed
Flesh out Python side and add unit test
1 parent 50974af commit c9551ba

File tree

4 files changed

+146
-15
lines changed

4 files changed

+146
-15
lines changed

python/datafusion/catalog.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,61 @@
3838

3939
__all__ = [
4040
"Catalog",
41+
"CatalogList",
4142
"CatalogProvider",
43+
"CatalogProviderList",
4244
"Schema",
4345
"SchemaProvider",
4446
"Table",
4547
]
4648

4749

50+
class CatalogList:
    """DataFusion data catalog list.

    Thin Python wrapper around a raw catalog list object. Every method
    delegates to the wrapped object stored in ``self.catalog_list``.
    """

    def __init__(self, catalog_list: df_internal.catalog.RawCatalogList) -> None:
        """This constructor is not typically called by the end user."""
        self.catalog_list = catalog_list

    def __repr__(self) -> str:
        """Print a string representation of the catalog list."""
        return repr(self.catalog_list)

    def names(self) -> set[str]:
        """This is an alias for `catalog_names`."""
        return self.catalog_names()

    def catalog_names(self) -> set[str]:
        """Returns the set of catalog names in this catalog list."""
        return self.catalog_list.catalog_names()

    @staticmethod
    def memory_catalog(ctx: SessionContext | None = None) -> CatalogList:
        """Create an in-memory catalog provider list.

        Args:
            ctx: Optional session context forwarded to the underlying
                ``RawCatalogList.memory_catalog`` constructor.

        Returns:
            A new :class:`CatalogList` wrapping the created raw list.
        """
        catalog_list = df_internal.catalog.RawCatalogList.memory_catalog(ctx)
        return CatalogList(catalog_list)

    def catalog(self, name: str = "datafusion") -> Catalog:
        """Returns the catalog with the given ``name`` from this catalog list.

        Args:
            name: Name of the catalog to look up.

        Returns:
            The catalog found under ``name``. Raw catalogs are wrapped in
            :class:`Catalog`; any other object is returned unchanged.
        """
        catalog = self.catalog_list.catalog(name)

        # Only raw catalogs need wrapping; other objects are handed back as-is.
        return (
            Catalog(catalog)
            if isinstance(catalog, df_internal.catalog.RawCatalog)
            else catalog
        )

    def register_catalog(
        self,
        name: str,
        catalog: Catalog | CatalogProvider | CatalogProviderExportable,
    ) -> Catalog | None:
        """Register a catalog with this catalog list.

        Args:
            name: Name to register the catalog under.
            catalog: Catalog to register. A :class:`Catalog` wrapper is
                unwrapped to its underlying ``catalog`` attribute first; any
                other object is passed through unchanged.

        Returns:
            Whatever the underlying ``register_catalog`` call returns.
        """
        if isinstance(catalog, Catalog):
            return self.catalog_list.register_catalog(name, catalog.catalog)
        return self.catalog_list.register_catalog(name, catalog)
94+
95+
4896
class Catalog:
4997
"""DataFusion data catalog."""
5098

@@ -195,6 +243,38 @@ def kind(self) -> str:
195243
return self._inner.kind
196244

197245

246+
class CatalogProviderList(ABC):
    """Base class for catalog provider lists implemented in Python.

    Subclasses must implement :meth:`catalog_names` and :meth:`catalog`.
    :meth:`register_catalog` is optional and does nothing by default.
    """

    @abstractmethod
    def catalog_names(self) -> set[str]:
        """Return the names of every catalog held by this catalog list."""

    @abstractmethod
    def catalog(self, name: str) -> Catalog | None:
        """Look up the catalog stored under ``name`` in this catalog list."""

    def register_catalog(  # noqa: B027
        self, name: str, catalog: CatalogProviderExportable | CatalogProvider | Catalog
    ) -> None:
        """Add a catalog to this catalog list.

        Overriding this method is optional. A catalog list that exposes a
        fixed set of catalogs can leave this default no-op in place.
        """
267+
268+
269+
class CatalogProviderListExportable(Protocol):
    """Protocol for objects exporting a catalog provider list PyCapsule.

    An object satisfies this type hint when it defines the
    ``__datafusion_catalog_provider_list__`` method.

    https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProviderList.html
    """

    def __datafusion_catalog_provider_list__(  # noqa: D105
        self, session: Any
    ) -> object:
        ...
276+
277+
198278
class CatalogProvider(ABC):
199279
"""Abstract class for defining a Python based Catalog Provider."""
200280

@@ -229,6 +309,15 @@ def deregister_schema(self, name: str, cascade: bool) -> None: # noqa: B027
229309
"""
230310

231311

312+
class CatalogProviderExportable(Protocol):
    """Protocol for objects exporting a catalog provider PyCapsule.

    An object satisfies this type hint when it defines the
    ``__datafusion_catalog_provider__`` method.

    https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
    """

    def __datafusion_catalog_provider__(  # noqa: D105
        self, session: Any
    ) -> object:
        ...
319+
320+
232321
class SchemaProvider(ABC):
233322
"""Abstract class for defining a Python based Schema Provider."""
234323

python/datafusion/context.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@
3131

3232
import pyarrow as pa
3333

34-
from datafusion.catalog import Catalog
34+
from datafusion.catalog import (
35+
Catalog,
36+
CatalogList,
37+
CatalogProviderExportable,
38+
CatalogProviderList,
39+
CatalogProviderListExportable,
40+
)
3541
from datafusion.dataframe import DataFrame
3642
from datafusion.expr import sort_list_to_raw_sort_list
3743
from datafusion.record_batch import RecordBatchStream
@@ -91,15 +97,6 @@ class TableProviderExportable(Protocol):
9197
def __datafusion_table_provider__(self, session: Any) -> object: ... # noqa: D105
9298

9399

94-
class CatalogProviderExportable(Protocol):
95-
"""Type hint for object that has __datafusion_catalog_provider__ PyCapsule.
96-
97-
https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
98-
"""
99-
100-
def __datafusion_catalog_provider__(self, session: Any) -> object: ... # noqa: D105
101-
102-
103100
class SessionConfig:
104101
"""Session configuration options."""
105102

@@ -832,6 +829,16 @@ def catalog_names(self) -> set[str]:
832829
"""Returns the list of catalogs in this context."""
833830
return self.ctx.catalog_names()
834831

832+
def register_catalog_provider_list(
    self,
    provider: CatalogProviderListExportable | CatalogProviderList | CatalogList,
) -> None:
    """Register a catalog provider list.

    Args:
        provider: The catalog provider list to register. A
            :class:`CatalogList` wrapper is unwrapped to the raw object it
            holds in ``catalog_list``; any other provider is passed through
            unchanged.
    """
    # ``CatalogList.catalog`` is a lookup method, not the wrapped object.
    # The raw list lives in ``catalog_list``, so that is what gets forwarded.
    target = provider.catalog_list if isinstance(provider, CatalogList) else provider
    self.ctx.register_catalog_provider_list(target)
841+
835842
def register_catalog_provider(
836843
self, name: str, provider: CatalogProviderExportable | CatalogProvider | Catalog
837844
) -> None:

python/tests/test_catalog.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,16 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
from typing import TYPE_CHECKING
20+
1921
import datafusion as dfn
2022
import pyarrow as pa
2123
import pyarrow.dataset as ds
2224
import pytest
23-
from datafusion import SessionContext, Table, udtf
25+
from datafusion import Catalog, SessionContext, Table, udtf
26+
27+
if TYPE_CHECKING:
28+
from datafusion.catalog import CatalogProvider, CatalogProviderExportable
2429

2530

2631
# Note we take in `database` as a variable even though we don't use
@@ -93,6 +98,34 @@ def deregister_schema(self, name, cascade: bool):
9398
del self.schemas[name]
9499

95100

101+
class CustomCatalogProviderList(dfn.catalog.CatalogProviderList):
    """Dict-backed catalog provider list used to exercise the Python interface."""

    def __init__(self):
        # Start with a single pre-registered catalog.
        self.catalogs = {"my_catalog": CustomCatalogProvider()}

    def catalog_names(self) -> set[str]:
        """Return the names of all catalogs currently stored."""
        return set(self.catalogs)

    def catalog(self, name: str) -> Catalog | None:
        """Return the catalog stored under ``name``, or ``None`` if absent."""
        # ``.get`` keeps the ``Catalog | None`` contract for unknown names
        # instead of raising ``KeyError``.
        return self.catalogs.get(name)

    def register_catalog(
        self, name: str, catalog: CatalogProviderExportable | CatalogProvider | Catalog
    ) -> None:
        """Store ``catalog`` under ``name``, replacing any existing entry."""
        self.catalogs[name] = catalog
115+
116+
117+
def test_python_catalog_provider_list(ctx: SessionContext):
    """Check that a Python catalog provider list can be registered and extended."""
    provider_list = CustomCatalogProviderList()
    ctx.register_catalog_provider_list(provider_list)

    # The catalog list was replaced, so the default `datafusion` catalog
    # must no longer be listed.
    names_after_replace = ctx.catalog_names()
    assert names_after_replace == {"my_catalog"}

    # Registering a further catalog provider must show up in the listing.
    second_provider = CustomCatalogProvider()
    ctx.register_catalog_provider("second_catalog", second_provider)
    names_after_register = ctx.catalog_names()
    assert names_after_register == {"my_catalog", "second_catalog"}
127+
128+
96129
def test_python_catalog_provider(ctx: SessionContext):
97130
ctx.register_catalog_provider("my_catalog", CustomCatalogProvider())
98131

src/catalog.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -458,8 +458,9 @@ impl CatalogProvider for RustWrappedPyCatalogProvider {
458458
Python::attach(|py| {
459459
let provider = self.catalog_provider.bind(py);
460460
provider
461-
.getattr("schema_names")
462-
.and_then(|names| names.extract::<Vec<String>>())
461+
.call_method0("schema_names")
462+
.and_then(|names| names.extract::<HashSet<String>>())
463+
.map(|names| names.into_iter().collect())
463464
.unwrap_or_else(|err| {
464465
log::error!("Unable to get schema_names: {err}");
465466
Vec::default()
@@ -565,8 +566,9 @@ impl CatalogProviderList for RustWrappedPyCatalogProviderList {
565566
Python::attach(|py| {
566567
let provider = self.catalog_provider_list.bind(py);
567568
provider
568-
.getattr("catalog_names")
569-
.and_then(|names| names.extract::<Vec<String>>())
569+
.call_method0("catalog_names")
570+
.and_then(|names| names.extract::<HashSet<String>>())
571+
.map(|names| names.into_iter().collect())
570572
.unwrap_or_else(|err| {
571573
log::error!("Unable to get catalog_names: {err}");
572574
Vec::default()

0 commit comments

Comments
 (0)