This repository was archived by the owner on Jul 16, 2024. It is now read-only.

Make it work with PL/Container #223

Open · wants to merge 6 commits into base: main
Changes from 1 commit
Support PL/Container
Xuebin Su committed Oct 27, 2023
commit c00bbdcbaa0e8d6ab5093fc3b26e4f40f5a29d9f
125 changes: 75 additions & 50 deletions greenplumpython/func.py
@@ -18,8 +18,9 @@
from greenplumpython.db import Database
from greenplumpython.expr import Expr, _serialize_to_expr
from greenplumpython.group import DataFrameGroupingSet
-from greenplumpython.type import _serialize_to_type
+from greenplumpython.type import _serialize_to_type_name, _defined_types

+import psycopg2

class FunctionExpr(Expr):
"""
@@ -111,52 +112,68 @@ def apply(
if grouping_col_names is not None and len(grouping_col_names) != 0
else None
)
-        unexpanded_dataframe = DataFrame(
-            " ".join(
-                [
-                    f"SELECT {_serialize_to_expr(self, db=db)} {'AS ' + column_name if column_name is not None else ''}",
-                    ("," + ",".join(grouping_cols)) if (grouping_cols is not None) else "",
-                    from_clause,
-                    group_by_clause,
-                ]
-            ),
-            db=db,
-            parents=parents,
-        )
-        # We use 2 `DataFrame`s here because on GPDB 6X and PostgreSQL <= 9.6, a
-        # function returning records that contains more than one attributes
-        # will be called multiple times if we do
-        # ```sql
-        # SELECT (func(a, b)).* FROM t;
-        # ```
-        # which might cause performance issue. To workaround we need to do
-        # ```sql
-        # WITH func_call AS (
-        #     SELECT func(a, b) AS result FROM t
-        # )
-        # SELECT (result).* FROM func_call;
-        # ```
-        rebased_grouping_cols = (
-            [_serialize_to_expr(unexpanded_dataframe[name], db=db) for name in grouping_col_names]
-            if grouping_col_names is not None
-            else None
-        )
-        result_cols = (
-            _serialize_to_expr(unexpanded_dataframe["*"], db=db)
-            if not expand
-            else _serialize_to_expr(unexpanded_dataframe[column_name]["*"], db=db)
-            # `len(rebased_grouping_cols) == 0` means `GROUP BY GROUPING SETS (())`
-            if rebased_grouping_cols is None or len(rebased_grouping_cols) == 0
-            else f"({unexpanded_dataframe._name}).*"
-            if not expand
-            else f"{','.join(rebased_grouping_cols)}, {_serialize_to_expr(unexpanded_dataframe[column_name]['*'], db=db)}"
-        )
-
-        return DataFrame(
-            f"SELECT {result_cols} FROM {unexpanded_dataframe._name}",
-            db=db,
-            parents=[unexpanded_dataframe],
-        )
+        try:
+            return_annotation = inspect.signature(self._function._wrapped_func).return_annotation  # type: ignore reportUnknownArgumentType
+            _serialize_to_type_name(return_annotation, db=db, for_return=True)
+            return DataFrame(
+                f"""
+                SELECT * FROM plcontainer_apply(TABLE(
+                    SELECT * {from_clause}), '{self._function._qualified_name_str}', 4096) AS
+                    {_defined_types[return_annotation.__args__[0]]._serialize(db=db)}
+                """,
+                db=db,
+                parents=parents,
+            )
+        except psycopg2.errors.InternalError_:
+            unexpanded_dataframe = DataFrame(
+                " ".join(
+                    [
+                        f"SELECT {_serialize_to_expr(self, db=db)} {'AS ' + column_name if column_name is not None else ''}",
+                        ("," + ",".join(grouping_cols)) if (grouping_cols is not None) else "",
+                        from_clause,
+                        group_by_clause,
+                    ]
+                ),
+                db=db,
+                parents=parents,
+            )
+            # We use 2 `DataFrame`s here because on GPDB 6X and PostgreSQL <= 9.6, a
+            # function returning records that contains more than one attributes
+            # will be called multiple times if we do
+            # ```sql
+            # SELECT (func(a, b)).* FROM t;
+            # ```
+            # which might cause performance issue. To workaround we need to do
+            # ```sql
+            # WITH func_call AS (
+            #     SELECT func(a, b) AS result FROM t
+            # )
+            # SELECT (result).* FROM func_call;
+            # ```
+            rebased_grouping_cols = (
+                [
+                    _serialize_to_expr(unexpanded_dataframe[name], db=db)
+                    for name in grouping_col_names
+                ]
+                if grouping_col_names is not None
+                else None
+            )
+            result_cols = (
+                _serialize_to_expr(unexpanded_dataframe["*"], db=db)
+                if not expand
+                else _serialize_to_expr(unexpanded_dataframe[column_name]["*"], db=db)
+                # `len(rebased_grouping_cols) == 0` means `GROUP BY GROUPING SETS (())`
+                if rebased_grouping_cols is None or len(rebased_grouping_cols) == 0
+                else f"({unexpanded_dataframe._name}).*"
+                if not expand
+                else f"{','.join(rebased_grouping_cols)}, {_serialize_to_expr(unexpanded_dataframe[column_name]['*'], db=db)}"
+            )
+
+            return DataFrame(
+                f"SELECT {result_cols} FROM {unexpanded_dataframe._name}",
+                db=db,
+                parents=[unexpanded_dataframe],
+            )

@property
def _function(self) -> "_AbstractFunction":
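For orientation, here is a rough sketch (not part of the diff) of the query shape the new PL/Container branch builds. The names `my_df` and `add_one` and the `(i integer)` column list are hypothetical; the column list stands in for whatever `DataType._serialize()` (added in `greenplumpython/type.py` below) returns for the function's composite return type.

```python
# Hypothetical sketch only -- the names and values below are illustrative.
from_clause = 'FROM "my_df"'  # what apply() passes as the FROM clause
qualified_name = "add_one"    # self._function._qualified_name_str
column_list = "(i integer)"   # assumed DataType._serialize() output for a one-field type

query = f"""
SELECT * FROM plcontainer_apply(TABLE(
    SELECT * {from_clause}), '{qualified_name}', 4096) AS
    {column_list}
"""
print(query)
```

If this path fails with `psycopg2.errors.InternalError_`, the `except` branch above keeps the pre-existing two-`DataFrame` plan as a fallback.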
@@ -272,12 +289,14 @@ def __init__(
name: Optional[str] = None,
schema: Optional[str] = None,
language_handler: Literal["plpython3u"] = "plpython3u",
+        runtime: Optional[str] = None
) -> None:
# noqa D107
super().__init__(wrapped_func, name, schema)
self._created_in_dbs: Optional[Set[Database]] = set() if wrapped_func is not None else None
self._wrapped_func = wrapped_func
self._language_handler = language_handler
+        self._runtime = runtime

def unwrap(self) -> Callable[..., Any]:
"""Get the wrapped Python function in the database function."""
@@ -302,14 +321,18 @@ def _serialize(self, db: Database) -> str:
func_sig = inspect.signature(self._wrapped_func)
func_args = ",".join(
[
f'"{param.name}" {_serialize_to_type(param.annotation, db=db)}'
f'"{param.name}" {_serialize_to_type_name(param.annotation, db=db)}'
for param in func_sig.parameters.values()
]
)
func_arg_names = ",".join(
[f"{param.name}={param.name}" for param in func_sig.parameters.values()]
)
-        return_type = _serialize_to_type(func_sig.return_annotation, db=db, for_return=True)
+        return_type = (
+            _serialize_to_type_name(func_sig.return_annotation, db=db, for_return=True)
+            if self._language_handler != "plcontainer"
+            else "SETOF record"
+        )
func_pickled: bytes = dill.dumps(self._wrapped_func)
_, func_name = self._qualified_name
# Modify the AST of the wrapped function to minify dependency: (1-3)
@@ -335,6 +358,7 @@ def _serialize(self, db: Database) -> str:
f"CREATE FUNCTION {self._qualified_name_str} ({func_args}) "
f"RETURNS {return_type} "
f"AS $$\n"
f"# container: {self._runtime}\n"
f"try:\n"
f" return GD['{func_ast.name}']({func_arg_names})\n"
f"except KeyError:\n"
@@ -461,7 +485,7 @@ def _create_in_db(self, db: Database) -> None:
state_param = next(param_list)
args_string = ",".join(
[
f"{param.name} {_serialize_to_type(param.annotation, db=db)}"
f"{param.name} {_serialize_to_type_name(param.annotation, db=db)}"
for param in param_list
]
)
@@ -470,7 +494,7 @@
(
f"CREATE AGGREGATE {self._qualified_name_str} ({args_string}) (\n"
f" SFUNC = {self.transition_function._qualified_name_str},\n"
f" STYPE = {_serialize_to_type(state_param.annotation, db=db)}\n"
f" STYPE = {_serialize_to_type_name(state_param.annotation, db=db)}\n"
f");\n"
),
has_results=False,
@@ -547,6 +571,7 @@ def aggregate_function(name: str, schema: Optional[str] = None) -> AggregateFunc
def create_function(
wrapped_func: Optional[Callable[..., Any]] = None,
language_handler: Literal["plpython3u"] = "plpython3u",
+    runtime: Optional[str] = None
) -> NormalFunction:
"""
Create a :class:`~func.NormalFunction` from the given Python function.
@@ -610,8 +635,8 @@
"""
# If user needs extra parameters when creating a function
if wrapped_func is None:
-        return functools.partial(create_function, language_handler=language_handler)
-    return NormalFunction(wrapped_func=wrapped_func, language_handler=language_handler)
+        return functools.partial(create_function, language_handler=language_handler, runtime=runtime)
+    return NormalFunction(wrapped_func=wrapped_func, language_handler=language_handler, runtime=runtime)


# FIXME: Add test cases for optional parameters
31 changes: 19 additions & 12 deletions greenplumpython/type.py
@@ -94,6 +94,17 @@ def __init__(
if self._modifier is not None:
self._qualified_name_str += f"({self._modifier})"

+    def _serialize(self, db: Database) -> str:
+        if self._annotation is None:
+            raise Exception("No type annotation to serialize")
+        members = get_type_hints(self._annotation)
+        if len(members) == 0:
+            raise Exception(f"Failed to get annotations for type {self._annotation}")
+        members_str = ",\n".join(
+            [f"{name} {_serialize_to_type_name(type_t, db)}" for name, type_t in members.items()]
+        )
+        return f"({members_str})"

# -- Creation of a composite type in Greenplum corresponding to the class_type given
def _create_in_db(self, db: Database):
# noqa: D400
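A minimal illustration (not from the diff) of what the new `_serialize` produces for a small composite type; the `int -> "integer"` / `str -> "text"` mapping is an assumption standing in for `_serialize_to_type_name`'s actual output.

```python
# Hypothetical sketch only: mirrors the join logic of DataType._serialize().
from dataclasses import dataclass
from typing import get_type_hints

@dataclass
class Pair:
    i: int
    name: str

assumed_type_names = {int: "integer", str: "text"}  # assumed primitive mapping
members = get_type_hints(Pair)
members_str = ",\n".join(f"{n} {assumed_type_names[t]}" for n, t in members.items())
print(f"({members_str})")  # -> "(i integer,\nname text)", usable after CREATE TYPE ... AS
```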
@@ -115,14 +126,9 @@ def _create_in_db(self, db: Database):
self._annotation, type
), "Only composite data types can be created in database."
schema = "pg_temp"
-        members = get_type_hints(self._annotation)
-        if len(members) == 0:
-            raise Exception(f"Failed to get annotations for type {self._annotation}")
-        att_type_str = ",\n".join(
-            [f"{name} {_serialize_to_type(type_t, db)}" for name, type_t in members.items()]
-        )

db._execute(
-            f'CREATE TYPE "{schema}"."{self._name}" AS (\n' f"{att_type_str}\n" f");",
+            f'CREATE TYPE "{schema}"."{self._name}" AS {self._serialize(db=db)};',
has_results=False,
)
self._created_in_dbs.add(db)
@@ -178,7 +184,7 @@ def type_(name: str, schema: Optional[str] = None, modifier: Optional[int] = Non
return DataType(name, schema=schema, modifier=modifier)


-def _serialize_to_type(
+def _serialize_to_type_name(
annotation: Union[DataType, type],
db: Database,
for_return: bool = False,
@@ -204,10 +210,10 @@ def _serialize_to_type(
if annotation.__origin__ == list or annotation.__origin__ == List:
args: Tuple[type, ...] = annotation.__args__
if for_return:
return f"SETOF {_serialize_to_type(args[0], db)}" # type: ignore
if args[0] in _defined_types:
return f"{_serialize_to_type(args[0], db)}[]" # type: ignore
raise NotImplementedError()
return f"SETOF {_serialize_to_type_name(args[0], db)}" # type: ignore
else:
return f"{_serialize_to_type_name(args[0], db)}[]" # type: ignore
raise NotImplementedError("Only list is supported as generic data type")
else:
if isinstance(annotation, DataType):
return annotation._qualified_name_str
@@ -216,4 +222,5 @@
type_name = "type_" + uuid4().hex
_defined_types[annotation] = DataType(name=type_name, annotation=annotation)
_defined_types[annotation]._create_in_db(db)
+        print(_defined_types)
return _defined_types[annotation]._qualified_name_str
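And a small sketch (again hypothetical) of what the reworked `list[...]` branch of `_serialize_to_type_name` is expected to yield, assuming `int` maps to `integer`:

```python
# Hypothetical sketch only: mimics the list[...] branch above with a hard-coded
# primitive mapping instead of the real recursive call.
from typing import List, get_args

def sketch_serialize(annotation, for_return: bool = False) -> str:
    element = get_args(annotation)[0]
    assumed = {int: "integer", str: "text"}[element]
    return f"SETOF {assumed}" if for_return else f"{assumed}[]"

print(sketch_serialize(List[int], for_return=True))  # SETOF integer
print(sketch_serialize(List[int]))                   # integer[]
```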
23 changes: 23 additions & 0 deletions tests/test_plcontainer.py
@@ -0,0 +1,23 @@
from dataclasses import dataclass
import greenplumpython as gp

from tests import db


def test_simple_func(db: gp.Database):
@dataclass
class Int:
i: int

@gp.create_function(language_handler="plcontainer", runtime="plc_python_example")
def add_one(x: list[Int]) -> list[Int]:
return [{"i": arg["i"] + 1} for arg in x]

assert (
len(
list(
db.create_dataframe(columns={"i": range(10)}).apply(lambda _: add_one(), expand=True)
)
)
== 10
)
6 changes: 3 additions & 3 deletions tests/test_type.py
@@ -3,7 +3,7 @@
import pytest

import greenplumpython as gp
-from greenplumpython.type import _serialize_to_type
+from greenplumpython.type import _serialize_to_type_name
from tests import db


@@ -76,7 +76,7 @@ class Person:
_first_name: str
_last_name: str

-    type_name = _serialize_to_type(Person, db=db)
+    type_name = _serialize_to_type_name(Person, db=db)
assert isinstance(type_name, str)


@@ -88,5 +88,5 @@ def __init__(self, _first_name: str, _last_name: str) -> None:
self._last_name = _last_name

with pytest.raises(Exception) as exc_info:
-        _serialize_to_type(Person, db=db)
+        _serialize_to_type_name(Person, db=db)
assert "Failed to get annotations" in str(exc_info.value)