From ec8cb733b9bf6d15e0d4ac36a961727ab28b238f Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 22 Feb 2023 01:12:29 +0100 Subject: [PATCH] Typing: Use modern syntax for `aiida.engine.processes.functions` Use the new union syntax from PEP 604. --- .pre-commit-config.yaml | 1 - aiida/engine/processes/functions.py | 35 +++++++++++++++-------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index be5aaa3cc9..4c9f688f07 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -110,7 +110,6 @@ repos: aiida/engine/processes/calcjobs/monitors.py| aiida/engine/processes/calcjobs/tasks.py| aiida/engine/processes/control.py| - aiida/engine/processes/functions.py| aiida/engine/processes/ports.py| aiida/manage/configuration/__init__.py| aiida/manage/configuration/config.py| diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py index 2a2dd7f610..69dbeebb9a 100644 --- a/aiida/engine/processes/functions.py +++ b/aiida/engine/processes/functions.py @@ -15,7 +15,8 @@ import inspect import logging import signal -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Tuple, Type, TypeVar +import typing as t +from typing import TYPE_CHECKING from aiida.common.lang import override from aiida.manage import get_manager @@ -31,7 +32,7 @@ LOGGER = logging.getLogger(__name__) -FunctionType = TypeVar('FunctionType', bound=Callable[..., Any]) +FunctionType = t.TypeVar('FunctionType', bound=t.Callable[..., t.Any]) def calcfunction(function: FunctionType) -> FunctionType: @@ -88,14 +89,14 @@ def workfunction(function: FunctionType) -> FunctionType: return process_function(node_class=WorkFunctionNode)(function) -def process_function(node_class: Type['ProcessNode']) -> Callable[[Callable[..., Any]], Callable[..., Any]]: +def process_function(node_class: t.Type['ProcessNode']) -> t.Callable[[FunctionType], FunctionType]: """ The base function decorator to create a FunctionProcess out of a normal python function. :param node_class: the ORM class to be used as the Node record for the FunctionProcess """ - def decorator(function: Callable[..., Any]) -> Callable[..., Any]: + def decorator(function: FunctionType) -> FunctionType: """ Turn the decorated function into a FunctionProcess. @@ -104,7 +105,7 @@ def decorator(function: Callable[..., Any]) -> Callable[..., Any]: """ process_class = FunctionProcess.build(function, node_class=node_class) - def run_get_node(*args, **kwargs) -> Tuple[Optional[Dict[str, Any]], 'ProcessNode']: + def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, 'ProcessNode']: """ Run the FunctionProcess with the supplied inputs in a local runner. @@ -159,7 +160,7 @@ def kill_process(_num, _frame): return result, process.node - def run_get_pk(*args, **kwargs) -> Tuple[Optional[Dict[str, Any]], int]: + def run_get_pk(*args, **kwargs) -> tuple[dict[str, t.Any] | None, int]: """Recreate the `run_get_pk` utility launcher. :param args: input arguments to construct the FunctionProcess @@ -185,7 +186,7 @@ def decorated_function(*args, **kwargs): decorated_function.recreate_from = process_class.recreate_from # type: ignore[attr-defined] decorated_function.spec = process_class.spec # type: ignore[attr-defined] - return decorated_function + return decorated_function # type: ignore[return-value] return decorator @@ -193,7 +194,7 @@ def decorated_function(*args, **kwargs): class FunctionProcess(Process): """Function process class used for turning functions into a Process""" - _func_args: Sequence[str] = () + _func_args: t.Sequence[str] = () _varargs: str | None = None @staticmethod @@ -205,7 +206,7 @@ def _func(*_args, **_kwargs) -> dict: return {} @staticmethod - def build(func: Callable[..., Any], node_class: Type['ProcessNode']) -> Type['FunctionProcess']: + def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['FunctionProcess']: """ Build a Process from the given function. @@ -274,7 +275,7 @@ def _define(cls, spec): # pylint: disable=unused-argument ): indirect_default = lambda value=default: to_aiida_type(value) else: - indirect_default = default # type: ignore[assignment] + indirect_default = default spec.input(parameter.name, valid_type=valid_type, default=indirect_default, serializer=to_aiida_type) @@ -306,7 +307,7 @@ def _define(cls, spec): # pylint: disable=unused-argument ) @classmethod - def validate_inputs(cls, *args: Any, **kwargs: Any) -> None: # pylint: disable=unused-argument + def validate_inputs(cls, *args: t.Any, **kwargs: t.Any) -> None: # pylint: disable=unused-argument """ Validate the positional and keyword arguments passed in the function call. @@ -327,7 +328,7 @@ def validate_inputs(cls, *args: Any, **kwargs: Any) -> None: # pylint: disable= raise TypeError(f'{name}() takes {nparameters} positional arguments but {nargs} were given') @classmethod - def create_inputs(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: + def create_inputs(cls, *args: t.Any, **kwargs: t.Any) -> dict[str, t.Any]: """Create the input args for the FunctionProcess.""" cls.validate_inputs(*args, **kwargs) @@ -339,7 +340,7 @@ def create_inputs(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: return ins @classmethod - def args_to_dict(cls, *args: Any) -> Dict[str, Any]: + def args_to_dict(cls, *args: t.Any) -> dict[str, t.Any]: """ Create an input dictionary (of form label -> value) from supplied args. @@ -388,7 +389,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(enable_persistence=False, *args, **kwargs) # type: ignore @property - def process_class(self) -> Callable[..., Any]: + def process_class(self) -> t.Callable[..., t.Any]: """ Return the class that represents this Process, for the FunctionProcess this is the function itself. @@ -401,7 +402,7 @@ class that really represents what was being executed. """ return self._func - def execute(self) -> Optional[Dict[str, Any]]: + def execute(self) -> dict[str, t.Any] | None: """Execute the process.""" result = super().execute() @@ -418,7 +419,7 @@ def _setup_db_record(self) -> None: self.node.store_source_info(self._func) @override - def run(self) -> Optional['ExitCode']: + def run(self) -> 'ExitCode' | None: """Run the process.""" from .exit_code import ExitCode @@ -427,7 +428,7 @@ def run(self) -> Optional['ExitCode']: # been overridden by the engine to `Running` so we cannot check that, but if the `exit_status` is anything other # than `None`, it should mean this node was taken from the cache, so the process should not be rerun. if self.node.exit_status is not None: - return self.node.exit_status + return ExitCode(self.node.exit_status, self.node.exit_message) # Split the inputs into positional and keyword arguments args = [None] * len(self._func_args)