diff --git a/CHANGELOG.md b/CHANGELOG.md
index d1bc07451..c9dfa7c01 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -79,6 +79,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 and `XarrayIO.to_netcdf_file()` ([#314](https://github.com/Open-EO/openeo-python-client/issues/314))
 - Changed argument name of `Connection.describe_collection()` from `name` to `collection_id`
   to be more in line with other methods/functions.
+- Rework and improve `openeo.UDF` helper class for UDF usage: allow loading from local file,
+  autodetect `runtime` from source code, ensure proper `from_parameter` value ([#312](https://github.com/Open-EO/openeo-python-client/issues/312))
 
 ### Fixed
 
diff --git a/openeo/__init__.py b/openeo/__init__.py
index 6d1363fe8..4705c1df2 100644
--- a/openeo/__init__.py
+++ b/openeo/__init__.py
@@ -13,10 +13,9 @@ class BaseOpenEoException(Exception):
 
 from openeo._version import __version__
 from openeo.imagecollection import ImageCollection
-from openeo.rest.datacube import DataCube
+from openeo.rest.datacube import DataCube, UDF
 from openeo.rest.connection import connect, session, Connection
 from openeo.rest.job import BatchJob, RESTJob
-from openeo.internal.graph_building import UDF
 
 
 def client_version() -> str:
diff --git a/openeo/internal/graph_building.py b/openeo/internal/graph_building.py
index d4de7a484..a6f88c935 100644
--- a/openeo/internal/graph_building.py
+++ b/openeo/internal/graph_building.py
@@ -156,26 +156,6 @@ def as_flat_graph(x: Union[dict, Any]) -> dict:
     raise ValueError(x)
 
 
-class UDF(PGNode):
-    """
-    A 'run_udf' process graph node. This is offered as a convenient way to construct run_udf processes.
-    """
-
-    def __init__(self, code: str, runtime: str, data, version: str = None, context: Dict = None):
-        arguments = {
-            "data": data,
-            "udf": code,
-            "runtime": runtime
-        }
-        if version is not None:
-            arguments["version"] = version
-
-        if context is not None:
-            arguments["context"] = context
-
-        super().__init__(process_id='run_udf', arguments=arguments)
-
-
 class ReduceNode(PGNode):
     """
     A process graph node for "reduce" processes (has a reducer sub-process-graph)
diff --git a/openeo/rest/datacube.py b/openeo/rest/datacube.py
index 148988193..fc2e9d77b 100644
--- a/openeo/rest/datacube.py
+++ b/openeo/rest/datacube.py
@@ -10,6 +10,7 @@
 import datetime
 import logging
 import pathlib
+import re
 import typing
 import warnings
 from builtins import staticmethod
@@ -48,6 +49,88 @@
 log = logging.getLogger(__name__)
 
 
+class UDF:
+    """
+    Helper class to load UDF code (e.g. from file) and embed it as a "callback" or child process in a process graph.
+
+    .. versionadded:: 0.11.0
+    """
+    __slots__ = ["code", "runtime", "version", "context", "_source"]
+
+    def __init__(
+            self, code: str, runtime: Optional[str] = None, version: Optional[str] = None,
+            context: Optional[dict] = None, _source=None,
+    ):
+        """
+        Construct a UDF object from given code string and other ``run_udf`` related arguments.
+        :param code: UDF source code string (Python, R, ...)
+        :param runtime: optional UDF runtime identifier, will be autodetected from source code if omitted.
+        :param version: optional UDF runtime version string
+        :param context: optional additional UDF context data
+        :param _source: (for internal use) source identifier
+        """
+        # TODO: automatically dedent code (when literal string)?
+        self.code = code
+        self.runtime = runtime
+        self.version = version
+        self.context = context
+        self._source = _source
+
+    @classmethod
+    def from_file(
+            cls, path: Union[str, pathlib.Path], runtime: Optional[str] = None, version: Optional[str] = None,
+            context: Optional[dict] = None
+    ):
+        """
+        Load a UDF from a local file.
+
+        :param path: path to the local file with UDF source code
+        :param runtime: optional UDF runtime identifier, will be autodetected from source code if omitted.
+        :param version: optional UDF runtime version string
+        :param context: optional additional UDF context data
+        :return:
+        """
+        path = pathlib.Path(path)
+        code = path.read_text(encoding="utf-8")
+        return cls(code=code, runtime=runtime, version=version, context=context, _source=path)
+
+    def _guess_runtime(self, connection: "openeo.Connection") -> str:
+        """Guess UDF runtime from UDF source (path) or source code."""
+        # First, guess UDF language
+        language = None
+        if isinstance(self._source, pathlib.Path):
+            language = {
+                ".py": "Python",
+                ".r": "R",
+            }.get(self._source.suffix.lower())
+        if not language:
+            # Guess language from UDF code
+            if re.search(r"^def [\w0-9_]+\(", self.code, flags=re.MULTILINE):
+                language = "Python"
+            # TODO: detection heuristics for R and other languages?
+        if not language:
+            raise OpenEoClientException("Failed to detect language of UDF code.")
+        # Find runtime for language
+        runtimes = {k.lower(): k for k in connection.list_udf_runtimes().keys()}
+        if language.lower() in runtimes:
+            return runtimes[language.lower()]
+        else:
+            raise OpenEoClientException(f"Failed to match UDF language {language!r} with a runtime ({runtimes})")
+
+    def get_run_udf_callback(self, connection: "openeo.Connection", data_parameter: str = "data") -> PGNode:
+        """
+        For internal use: construct `run_udf` node to be used as callback in `apply`, `reduce_dimension`, ...
+        """
+        arguments = dict_no_none(
+            data={"from_parameter": data_parameter},
+            udf=self.code,
+            runtime=self.runtime or self._guess_runtime(connection=connection),
+            version=self.version,
+            context=self.context,
+        )
+        return PGNode(process_id="run_udf", arguments=arguments)
+
+
 class DataCube(_ProcessGraphAbstraction):
     """
     Class representing a openEO (raster) data cube.
@@ -789,7 +872,11 @@ def aggregate_spatial(
         )
 
     @staticmethod
-    def _get_callback(process: Union[str, PGNode, typing.Callable], parent_parameters: List[str]) -> dict:
+    def _get_callback(
+            process: Union[str, PGNode, typing.Callable, UDF],
+            parent_parameters: List[str],
+            connection: Optional["openeo.Connection"] = None,
+    ) -> dict:
         """
         Build a "callback" process: a user defined process that is used by another process (such
         as `apply`, `apply_dimension`, `reduce`, ....)
@@ -820,6 +907,8 @@ def _get_callback(process: Union[str, PGNode, typing.Callable], parent_parameter
             pg = PGNode(process_id=process, arguments=arguments)
         elif isinstance(process, typing.Callable):
            pg = convert_callable_to_pgnode(process, parent_parameters=parent_parameters)
+        elif isinstance(process, UDF):
+            pg = process.get_run_udf_callback(connection=connection, data_parameter=parent_parameters[0])
         else:
             raise ValueError(process)
 
@@ -828,7 +917,7 @@ def _get_callback(process: Union[str, PGNode, typing.Callable], parent_parameter
     @openeo_process
     def apply_dimension(
             self, code: str = None, runtime=None,
-            process: Union[str, PGNode, typing.Callable] = None,
+            process: Union[str, PGNode, typing.Callable, UDF] = None,
             version="latest",
             # TODO: dimension has no default (per spec)?
             dimension="t",
@@ -873,7 +962,9 @@ def apply_dimension(
             process = PGNode.to_process_graph_argument(callback_process_node)
         elif code or process:
             # TODO EP-3555 unify `code` and `process`
-            process = self._get_callback(code or process, parent_parameters=["data", "context"])
+            process = self._get_callback(
+                process=code or process, parent_parameters=["data", "context"], connection=self.connection
+            )
         else:
             raise OpenEoClientException("No UDF code or process given")
         arguments = {
@@ -892,7 +983,8 @@ def apply_dimension(
     @openeo_process
     def reduce_dimension(
             self,
-            dimension: str, reducer: Union[str, PGNode, typing.Callable],
+            dimension: str,
+            reducer: Union[str, PGNode, typing.Callable, UDF],
             context: Optional[dict] = None,
             process_id="reduce_dimension", band_math_mode: bool = False
     ) -> "DataCube":
@@ -905,7 +997,9 @@ def reduce_dimension(
         """
         # TODO: check if dimension is valid according to metadata? #116
         # TODO: #125 use/test case for `reduce_dimension_binary`?
-        reducer = self._get_callback(reducer, parent_parameters=["data", "context"])
+        reducer = self._get_callback(
+            process=reducer, parent_parameters=["data", "context"], connection=self.connection
+        )
 
         return self.process_with_node(ReduceNode(
             process_id=process_id,
@@ -1035,7 +1129,7 @@ def reduce_temporal_udf(self, code: str, runtime="Python", version="latest"):
     @openeo_process
     def apply_neighborhood(
             self,
-            process: Union[str, PGNode, typing.Callable],
+            process: Union[str, PGNode, typing.Callable, UDF],
             size: List[Dict],
             overlap: List[dict] = None,
             context: Optional[dict] = None,
@@ -1064,7 +1158,7 @@ def apply_neighborhood(
             process_id='apply_neighborhood',
             arguments=dict_no_none(
                 data=THIS,
-                process=self._get_callback(process, parent_parameters=["data"]),
+                process=self._get_callback(process=process, parent_parameters=["data"], connection=self.connection),
                 size=size, overlap=overlap, context=context,
             )
         )
@@ -1072,7 +1166,7 @@
-    def apply(self, process: Union[str, PGNode, typing.Callable] = None, context: Optional[dict] = None) -> 'DataCube':
+    def apply(self, process: Union[str, PGNode, typing.Callable, UDF] = None, context: Optional[dict] = None) -> 'DataCube':
         """
         Applies a unary process (a local operation) to each value of the specified or all dimensions in the data cube.
@@ -1086,7 +1180,7 @@ def apply(self, process: Union[str, PGNode, typing.Callable] = None, context: Op
             process_id="apply",
             arguments=dict_no_none({
                 "data": THIS,
-                "process": self._get_callback(process, parent_parameters=["x"]),
+                "process": self._get_callback(process, parent_parameters=["x"], connection=self.connection),
                 "context": context,
             })
         )
diff --git a/tests/rest/datacube/test_datacube100.py b/tests/rest/datacube/test_datacube100.py
index 43955dc2e..984be934d 100644
--- a/tests/rest/datacube/test_datacube100.py
+++ b/tests/rest/datacube/test_datacube100.py
@@ -6,8 +6,6 @@
 import collections
 import io
 import pathlib
-import re
-import sys
 import textwrap
 from typing import Optional
 
@@ -22,7 +20,7 @@
 from openeo.internal.process_graph_visitor import ProcessGraphVisitException
 from openeo.rest import OpenEoClientException
 from openeo.rest.connection import Connection
-from openeo.rest.datacube import THIS, DataCube, ProcessBuilder
+from openeo.rest.datacube import THIS, DataCube, ProcessBuilder, UDF
 from .conftest import API_URL, setup_collection_metadata
 from ... import load_json_resource
 
@@ -2139,3 +2137,230 @@ def test_legacy_send_job(self, con100, requests_mock):
         with pytest.warns(DeprecationWarning, match="Call to deprecated method `send_job`, use `create_job` instead."):
             job = cube.send_job(out_format="GTiff")
         assert job.job_id == "myj0b1"
+
+
+class TestUDF:
+
+    def test_apply_udf_basic(self, con100):
+        udf = UDF("print('hello world')", runtime="Python")
+        cube = con100.load_collection("S2")
+        res = cube.apply(udf)
+
+        assert res.flat_graph() == {
+            "loadcollection1": {
+                "process_id": "load_collection",
+                "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None},
+            },
+            "apply1": {
+                "process_id": "apply",
+                "arguments": {
+                    "data": {"from_node": "loadcollection1"},
+                    "process": {
+                        "process_graph": {"runudf1": {
+                            "process_id": "run_udf",
+                            "arguments": {
+                                "data": {"from_parameter": "x"},
+                                "runtime": "Python",
+                                "udf": "print('hello world')",
+                            },
+                            "result": True,
+                        }},
+                    },
+                },
+                "result": True,
+            },
+        }
+
+    def test_apply_udf_runtime_detection(self, con100, requests_mock):
+        requests_mock.get(API_URL + "/udf_runtimes", json={
+            "Python": {"type": "language", "default": "3", "versions": {"3": {"libraries": {}}}},
+        })
+
+        udf = UDF("def foo(x):\n return x\n")
+        cube = con100.load_collection("S2")
+        res = cube.apply(udf)
+
+        assert res.flat_graph()["apply1"]["arguments"]["process"] == {
+            "process_graph": {"runudf1": {
+                "process_id": "run_udf",
+                "arguments": {
+                    "data": {"from_parameter": "x"},
+                    "runtime": "Python",
+                    "udf": "def foo(x):\n return x\n",
+                },
+                "result": True,
+            }},
+        }
+
+    @pytest.mark.parametrize(["filename", "udf_code", "expected_runtime"], [
+        ("udf-code.py", "def foo(x):\n return x\n", "Python"),
+        ("udf-code.py", "# just empty, but at least with `.py` suffix\n", "Python"),
+        ("udf-code-py.txt", "def foo(x):\n return x\n", "Python"),
+        ("udf-code.r", "# R code here\n", "R"),
+    ])
+    def test_apply_udf_load_from_file(self, con100, requests_mock, tmp_path, filename, udf_code, expected_runtime):
+        requests_mock.get(API_URL + "/udf_runtimes", json={
+            "Python": {"type": "language", "default": "3", "versions": {"3": {"libraries": {}}}},
+            "R": {"type": "language", "default": "3", "versions": {"3": {"libraries": {}}}},
+        })
+
+        path = tmp_path / filename
+        path.write_text(udf_code)
+
+        udf = UDF.from_file(path)
+        cube = con100.load_collection("S2")
+        res = cube.apply(udf)
+
+        assert res.flat_graph()["apply1"]["arguments"]["process"] == {
+            "process_graph": {"runudf1": {
+                "process_id": "run_udf",
+                "arguments": {
+                    "data": {"from_parameter": "x"},
+                    "runtime": expected_runtime,
+                    "udf": udf_code,
+                },
+                "result": True,
+            }},
+        }
+
+    @pytest.mark.parametrize(["kwargs"], [
+        ({"version": "3.8"},),
+        ({"context": {"color": "red"}},),
+    ])
+    def test_apply_udf_version_and_context(self, con100, requests_mock, kwargs):
+        requests_mock.get(API_URL + "/udf_runtimes", json={
+            "Python": {"type": "language", "default": "3", "versions": {"3": {"libraries": {}}}},
+        })
+
+        udf = UDF("def foo(x):\n return x\n", **kwargs)
+        cube = con100.load_collection("S2")
+        res = cube.apply(udf)
+
+        expected_args = {
+            "data": {"from_parameter": "x"},
+            "runtime": "Python",
+            "udf": "def foo(x):\n return x\n",
+        }
+        expected_args.update(kwargs)
+        assert res.flat_graph()["apply1"]["arguments"]["process"] == {
+            "process_graph": {"runudf1": {
+                "process_id": "run_udf",
+                "arguments": expected_args,
+                "result": True,
+            }},
+        }
+
+    def test_simple_apply_udf(self, con100, requests_mock):
+        requests_mock.get(API_URL + "/udf_runtimes", json={
+            "Python": {"type": "language", "default": "3", "versions": {"3": {"libraries": {}}}},
+        })
+
+        udf = UDF("def foo(x):\n return x\n")
+        cube = con100.load_collection("S2")
+        res = cube.apply(udf)
+
+        assert res.flat_graph()["apply1"] == {
+            "process_id": "apply",
+            "arguments": {
+                "data": {"from_node": "loadcollection1"},
+                "process": {
+                    "process_graph": {"runudf1": {
+                        "process_id": "run_udf",
+                        "arguments": {
+                            "data": {"from_parameter": "x"},
+                            "runtime": "Python",
+                            "udf": "def foo(x):\n return x\n",
+                        },
+                        "result": True,
+                    }},
+                },
+            },
+            "result": True,
+        }
+
+    def test_simple_apply_dimension_udf(self, con100, requests_mock):
+        requests_mock.get(API_URL + "/udf_runtimes", json={
+            "Python": {"type": "language", "default": "3", "versions": {"3": {"libraries": {}}}},
+        })
+
+        udf = UDF("def foo(x):\n return x\n")
+        cube = con100.load_collection("S2")
+        res = cube.apply_dimension(process=udf, dimension="t")
+
+        assert res.flat_graph()["applydimension1"] == {
+            "process_id": "apply_dimension",
+            "arguments": {
+                "data": {"from_node": "loadcollection1"},
+                "dimension": "t",
+                "process": {
+                    "process_graph": {"runudf1": {
+                        "process_id": "run_udf",
+                        "arguments": {
+                            "data": {"from_parameter": "data"},
+                            "runtime": "Python",
+                            "udf": "def foo(x):\n return x\n",
+                        },
+                        "result": True,
+                    }},
+                },
+            },
+            "result": True,
+        }
+
+    def test_simple_reduce_dimension_udf(self, con100, requests_mock):
+        requests_mock.get(API_URL + "/udf_runtimes", json={
+            "Python": {"type": "language", "default": "3", "versions": {"3": {"libraries": {}}}},
+        })
+
+        udf = UDF("def foo(x):\n return x\n")
+        cube = con100.load_collection("S2")
+        res = cube.reduce_dimension(reducer=udf, dimension="t")
+
+        assert res.flat_graph()["reducedimension1"] == {
+            "process_id": "reduce_dimension",
+            "arguments": {
+                "data": {"from_node": "loadcollection1"},
+                "dimension": "t",
+                "reducer": {
+                    "process_graph": {"runudf1": {
+                        "process_id": "run_udf",
+                        "arguments": {
+                            "data": {"from_parameter": "data"},
+                            "runtime": "Python",
+                            "udf": "def foo(x):\n return x\n",
+                        },
+                        "result": True,
+                    }},
+                },
+            },
+            "result": True,
+        }
+
+    def test_simple_apply_neighborhood_udf(self, con100, requests_mock):
+        requests_mock.get(API_URL + "/udf_runtimes", json={
+            "Python": {"type": "language", "default": "3", "versions": {"3": {"libraries": {}}}},
+        })
+
+        udf = UDF("def foo(x):\n return x\n")
+        cube = con100.load_collection("S2")
+        res = cube.apply_neighborhood(process=udf, size=27)
+
+        assert res.flat_graph()["applyneighborhood1"] == {
+            "process_id": "apply_neighborhood",
+            "arguments": {
+                "data": {"from_node": "loadcollection1"},
+                "size": 27,
+                "process": {
+                    "process_graph": {"runudf1": {
+                        "process_id": "run_udf",
+                        "arguments": {
+                            "data": {"from_parameter": "data"},
+                            "runtime": "Python",
+                            "udf": "def foo(x):\n return x\n",
+                        },
+                        "result": True,
+                    }},
+                },
+            },
+            "result": True,
+        }
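
A minimal usage sketch of the reworked `openeo.UDF` helper introduced by this patch. This is illustrative only and not part of the diff: the backend URL, the collection id "S2", the dummy UDF bodies (borrowed from the tests above), the file name "my_udf.py" and the output filename are placeholders.

# Usage sketch for the reworked openeo.UDF helper.
# Placeholders (not from this patch): backend URL, collection id "S2", "my_udf.py", output filename.
import openeo

connection = openeo.connect("https://openeo.example").authenticate_oidc()
cube = connection.load_collection("S2")

# Inline UDF code: the runtime is autodetected from the source code (here: Python).
udf = openeo.UDF("def foo(x):\n    return x\n")
result = cube.apply(udf)

# Or load the UDF from a local file: the runtime is guessed from the ".py"/".r" suffix,
# and version/context are passed through to the resulting "run_udf" node.
udf = openeo.UDF.from_file("my_udf.py", context={"scale": 2})
result = cube.reduce_dimension(reducer=udf, dimension="t")
result.download("result.nc")

Note that the `UDF` object itself can be built without a connection: runtime autodetection only happens when the `run_udf` callback is constructed (via `_get_callback`, which passes `self.connection` through to `get_run_udf_callback`).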