Skip to content
This repository was archived by the owner on Jul 16, 2024. It is now read-only.

Make it work with PL/Container #223

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 90 additions & 50 deletions greenplumpython/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from greenplumpython.db import Database
from greenplumpython.expr import Expr, _serialize_to_expr
from greenplumpython.group import DataFrameGroupingSet
from greenplumpython.type import _serialize_to_type
from greenplumpython.type import _defined_types, _serialize_to_type_name


class FunctionExpr(Expr):
Expand Down Expand Up @@ -111,52 +111,79 @@ def apply(
if grouping_col_names is not None and len(grouping_col_names) != 0
else None
)
unexpanded_dataframe = DataFrame(
" ".join(
if (
isinstance(self._function, NormalFunction)
and self._function._language_handler == "plcontainer"
):
return_annotation = inspect.signature(self._function._wrapped_func).return_annotation # type: ignore reportUnknownArgumentType
_serialize_to_type_name(return_annotation, db=db, for_return=True)
input_args = self._args
if len(input_args) == 0:
raise Exception("No input data specified, please specify a DataFrame or Columns")
input_clause = (
"*"
if (len(input_args) == 1 and isinstance(input_args[0], DataFrame))
else ",".join([arg._serialize(db=db) for arg in input_args])
)
return DataFrame(
f"""
SELECT * FROM plcontainer_apply(TABLE(
SELECT {input_clause} {from_clause}), '{self._function._qualified_name_str}', 4096) AS
{_defined_types[return_annotation.__args__[0]]._serialize(db=db)}
""",
db=db,
parents=parents,
)
else:
unexpanded_dataframe = DataFrame(
" ".join(
[
f"SELECT {_serialize_to_expr(self, db=db)} {'AS ' + column_name if column_name is not None else ''}",
("," + ",".join(grouping_cols)) if (grouping_cols is not None) else "",
from_clause,
group_by_clause,
]
),
db=db,
parents=parents,
)
# We use 2 `DataFrame`s here because on GPDB 6X and PostgreSQL <= 9.6, a
# function returning records that contains more than one attributes
# will be called multiple times if we do
# ```sql
# SELECT (func(a, b)).* FROM t;
# ```
# which might cause performance issue. To workaround we need to do
# ```sql
# WITH func_call AS (
# SELECT func(a, b) AS result FROM t
# )
# SELECT (result).* FROM func_call;
# ```
rebased_grouping_cols = (
[
f"SELECT {_serialize_to_expr(self, db=db)} {'AS ' + column_name if column_name is not None else ''}",
("," + ",".join(grouping_cols)) if (grouping_cols is not None) else "",
from_clause,
group_by_clause,
_serialize_to_expr(unexpanded_dataframe[name], db=db)
for name in grouping_col_names
]
),
db=db,
parents=parents,
)
# We use 2 `DataFrame`s here because on GPDB 6X and PostgreSQL <= 9.6, a
# function returning records that contains more than one attributes
# will be called multiple times if we do
# ```sql
# SELECT (func(a, b)).* FROM t;
# ```
# which might cause performance issue. To workaround we need to do
# ```sql
# WITH func_call AS (
# SELECT func(a, b) AS result FROM t
# )
# SELECT (result).* FROM func_call;
# ```
rebased_grouping_cols = (
[_serialize_to_expr(unexpanded_dataframe[name], db=db) for name in grouping_col_names]
if grouping_col_names is not None
else None
)
result_cols = (
_serialize_to_expr(unexpanded_dataframe["*"], db=db)
if not expand
else _serialize_to_expr(unexpanded_dataframe[column_name]["*"], db=db)
# `len(rebased_grouping_cols) == 0` means `GROUP BY GROUPING SETS (())`
if rebased_grouping_cols is None or len(rebased_grouping_cols) == 0
else f"({unexpanded_dataframe._name}).*"
if not expand
else f"{','.join(rebased_grouping_cols)}, {_serialize_to_expr(unexpanded_dataframe[column_name]['*'], db=db)}"
)
if grouping_col_names is not None
else None
)
result_cols = (
_serialize_to_expr(unexpanded_dataframe["*"], db=db)
if not expand
else _serialize_to_expr(unexpanded_dataframe[column_name]["*"], db=db)
# `len(rebased_grouping_cols) == 0` means `GROUP BY GROUPING SETS (())`
if rebased_grouping_cols is None or len(rebased_grouping_cols) == 0
else f"({unexpanded_dataframe._name}).*"
if not expand
else f"{','.join(rebased_grouping_cols)}, {_serialize_to_expr(unexpanded_dataframe[column_name]['*'], db=db)}"
)

return DataFrame(
f"SELECT {result_cols} FROM {unexpanded_dataframe._name}",
db=db,
parents=[unexpanded_dataframe],
)
return DataFrame(
f"SELECT {result_cols} FROM {unexpanded_dataframe._name}",
db=db,
parents=[unexpanded_dataframe],
)

@property
def _function(self) -> "_AbstractFunction":
Expand Down Expand Up @@ -272,12 +299,14 @@ def __init__(
name: Optional[str] = None,
schema: Optional[str] = None,
language_handler: Literal["plpython3u"] = "plpython3u",
runtime: Optional[str] = None,
) -> None:
# noqa D107
super().__init__(wrapped_func, name, schema)
self._created_in_dbs: Optional[Set[Database]] = set() if wrapped_func is not None else None
self._wrapped_func = wrapped_func
self._language_handler = language_handler
self._runtime = runtime

def unwrap(self) -> Callable[..., Any]:
"""Get the wrapped Python function in the database function."""
Expand All @@ -302,14 +331,18 @@ def _serialize(self, db: Database) -> str:
func_sig = inspect.signature(self._wrapped_func)
func_args = ",".join(
[
f'"{param.name}" {_serialize_to_type(param.annotation, db=db)}'
f'"{param.name}" {_serialize_to_type_name(param.annotation, db=db)}'
for param in func_sig.parameters.values()
]
)
func_arg_names = ",".join(
[f"{param.name}={param.name}" for param in func_sig.parameters.values()]
)
return_type = _serialize_to_type(func_sig.return_annotation, db=db, for_return=True)
return_type = (
_serialize_to_type_name(func_sig.return_annotation, db=db, for_return=True)
if self._language_handler != "plcontainer"
else "SETOF record"
)
func_pickled: bytes = dill.dumps(self._wrapped_func)
_, func_name = self._qualified_name
# Modify the AST of the wrapped function to minify dependency: (1-3)
Expand All @@ -335,6 +368,7 @@ def _serialize(self, db: Database) -> str:
f"CREATE FUNCTION {self._qualified_name_str} ({func_args}) "
f"RETURNS {return_type} "
f"AS $$\n"
f"# container: {self._runtime}\n"
f"try:\n"
f" return GD['{func_ast.name}']({func_arg_names})\n"
f"except KeyError:\n"
Expand All @@ -344,6 +378,7 @@ def _serialize(self, db: Database) -> str:
f" import sys as {sys_lib_name}\n"
f" if {sysconfig_lib_name}.get_python_version() != '{python_version}':\n"
f" raise ModuleNotFoundError\n"
f" {sys_lib_name}.modules['plpy']=plpy\n"
f" setattr({sys_lib_name}.modules['plpy'], '_SD', SD)\n"
f" GD['{func_ast.name}'] = {pickle_lib_name}.loads({func_pickled})\n"
f" except ModuleNotFoundError:\n"
Expand Down Expand Up @@ -461,7 +496,7 @@ def _create_in_db(self, db: Database) -> None:
state_param = next(param_list)
args_string = ",".join(
[
f"{param.name} {_serialize_to_type(param.annotation, db=db)}"
f"{param.name} {_serialize_to_type_name(param.annotation, db=db)}"
for param in param_list
]
)
Expand All @@ -470,7 +505,7 @@ def _create_in_db(self, db: Database) -> None:
(
f"CREATE AGGREGATE {self._qualified_name_str} ({args_string}) (\n"
f" SFUNC = {self.transition_function._qualified_name_str},\n"
f" STYPE = {_serialize_to_type(state_param.annotation, db=db)}\n"
f" STYPE = {_serialize_to_type_name(state_param.annotation, db=db)}\n"
f");\n"
),
has_results=False,
Expand Down Expand Up @@ -547,6 +582,7 @@ def aggregate_function(name: str, schema: Optional[str] = None) -> AggregateFunc
def create_function(
wrapped_func: Optional[Callable[..., Any]] = None,
language_handler: Literal["plpython3u"] = "plpython3u",
runtime: Optional[str] = None,
) -> NormalFunction:
"""
Create a :class:`~func.NormalFunction` from the given Python function.
Expand Down Expand Up @@ -610,8 +646,12 @@ def create_function(
"""
# If user needs extra parameters when creating a function
if wrapped_func is None:
return functools.partial(create_function, language_handler=language_handler)
return NormalFunction(wrapped_func=wrapped_func, language_handler=language_handler)
return functools.partial(
create_function, language_handler=language_handler, runtime=runtime
)
return NormalFunction(
wrapped_func=wrapped_func, language_handler=language_handler, runtime=runtime
)


# FIXME: Add test cases for optional parameters
Expand Down
30 changes: 18 additions & 12 deletions greenplumpython/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,17 @@ def __init__(
if self._modifier is not None:
self._qualified_name_str += f"({self._modifier})"

def _serialize(self, db: Database) -> str:
if self._annotation is None:
raise Exception("No type annotation to serialize")
members = get_type_hints(self._annotation)
if len(members) == 0:
raise Exception(f"Failed to get annotations for type {self._annotation}")
members_str = ",\n".join(
[f"{name} {_serialize_to_type_name(type_t, db)}" for name, type_t in members.items()]
)
return f"({members_str})"

# -- Creation of a composite type in Greenplum corresponding to the class_type given
def _create_in_db(self, db: Database):
# noqa: D400
Expand All @@ -115,14 +126,9 @@ def _create_in_db(self, db: Database):
self._annotation, type
), "Only composite data types can be created in database."
schema = "pg_temp"
members = get_type_hints(self._annotation)
if len(members) == 0:
raise Exception(f"Failed to get annotations for type {self._annotation}")
att_type_str = ",\n".join(
[f"{name} {_serialize_to_type(type_t, db)}" for name, type_t in members.items()]
)

db._execute(
f'CREATE TYPE "{schema}"."{self._name}" AS (\n' f"{att_type_str}\n" f");",
f'CREATE TYPE "{schema}"."{self._name}" AS {self._serialize(db=db)};',
has_results=False,
)
self._created_in_dbs.add(db)
Expand Down Expand Up @@ -178,7 +184,7 @@ def type_(name: str, schema: Optional[str] = None, modifier: Optional[int] = Non
return DataType(name, schema=schema, modifier=modifier)


def _serialize_to_type(
def _serialize_to_type_name(
annotation: Union[DataType, type],
db: Database,
for_return: bool = False,
Expand All @@ -204,10 +210,10 @@ def _serialize_to_type(
if annotation.__origin__ == list or annotation.__origin__ == List:
args: Tuple[type, ...] = annotation.__args__
if for_return:
return f"SETOF {_serialize_to_type(args[0], db)}" # type: ignore
if args[0] in _defined_types:
return f"{_serialize_to_type(args[0], db)}[]" # type: ignore
raise NotImplementedError()
return f"SETOF {_serialize_to_type_name(args[0], db)}" # type: ignore
else:
return f"{_serialize_to_type_name(args[0], db)}[]" # type: ignore
raise NotImplementedError("Only list is supported as generic data type")
else:
if isinstance(annotation, DataType):
return annotation._qualified_name_str
Expand Down
56 changes: 56 additions & 0 deletions tests/test_plcontainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from dataclasses import dataclass

import pytest

import greenplumpython as gp
from tests import db


@dataclass
class Int:
i: int


@dataclass
class Pair:
i: int
j: int


@pytest.fixture
def t(db: gp.Database):
rows = [(i, i) for i in range(10)]
return db.create_dataframe(rows=rows, column_names=["a", "b"])


@gp.create_function(language_handler="plcontainer", runtime="plc_python_example")
def add_one(x: list[Int]) -> list[Int]:
return [{"i": arg["i"] + 1} for arg in x]


def test_simple_func(db: gp.Database):
assert (
len(
list(
db.create_dataframe(columns={"i": range(10)}).apply(
lambda t: add_one(t), expand=True
)
)
)
== 10
)


def test_func_no_input(db: gp.Database):

with pytest.raises(Exception) as exc_info: # no input data for func raises Exception
db.create_dataframe(columns={"i": range(10)}).apply(lambda _: add_one(), expand=True)
assert "No input data specified, please specify a DataFrame or Columns" in str(exc_info.value)


def test_func_column(db: gp.Database, t: gp.DataFrame):
@gp.create_function(language_handler="plcontainer", runtime="plc_python_example")
def add(x: list[Pair]) -> list[Int]:
return [{"i": arg["i"] + arg["j"]} for arg in x]

assert len(list(t.apply(lambda t: add(t["a"], t["b"]), expand=True))) == 10
6 changes: 3 additions & 3 deletions tests/test_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

import greenplumpython as gp
from greenplumpython.type import _serialize_to_type
from greenplumpython.type import _serialize_to_type_name
from tests import db


Expand Down Expand Up @@ -76,7 +76,7 @@ class Person:
_first_name: str
_last_name: str

type_name = _serialize_to_type(Person, db=db)
type_name = _serialize_to_type_name(Person, db=db)
assert isinstance(type_name, str)


Expand All @@ -88,5 +88,5 @@ def __init__(self, _first_name: str, _last_name: str) -> None:
self._last_name = _last_name

with pytest.raises(Exception) as exc_info:
_serialize_to_type(Person, db=db)
_serialize_to_type_name(Person, db=db)
assert "Failed to get annotations" in str(exc_info.value)