Skip to content

Commit

Permalink
Typing: Use modern syntax for aiida.engine.processes.functions
Browse files Browse the repository at this point in the history
Use the new union syntax from PEP 604.
  • Loading branch information
sphuber committed Mar 16, 2023
1 parent e04a936 commit ec8cb73
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ repos:
aiida/engine/processes/calcjobs/monitors.py|
aiida/engine/processes/calcjobs/tasks.py|
aiida/engine/processes/control.py|
aiida/engine/processes/functions.py|
aiida/engine/processes/ports.py|
aiida/manage/configuration/__init__.py|
aiida/manage/configuration/config.py|
Expand Down
35 changes: 18 additions & 17 deletions aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import inspect
import logging
import signal
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Tuple, Type, TypeVar
import typing as t
from typing import TYPE_CHECKING

from aiida.common.lang import override
from aiida.manage import get_manager
Expand All @@ -31,7 +32,7 @@

LOGGER = logging.getLogger(__name__)

FunctionType = TypeVar('FunctionType', bound=Callable[..., Any])
FunctionType = t.TypeVar('FunctionType', bound=t.Callable[..., t.Any])


def calcfunction(function: FunctionType) -> FunctionType:
Expand Down Expand Up @@ -88,14 +89,14 @@ def workfunction(function: FunctionType) -> FunctionType:
return process_function(node_class=WorkFunctionNode)(function)


def process_function(node_class: Type['ProcessNode']) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def process_function(node_class: t.Type['ProcessNode']) -> t.Callable[[FunctionType], FunctionType]:
"""
The base function decorator to create a FunctionProcess out of a normal python function.
:param node_class: the ORM class to be used as the Node record for the FunctionProcess
"""

def decorator(function: Callable[..., Any]) -> Callable[..., Any]:
def decorator(function: FunctionType) -> FunctionType:
"""
Turn the decorated function into a FunctionProcess.
Expand All @@ -104,7 +105,7 @@ def decorator(function: Callable[..., Any]) -> Callable[..., Any]:
"""
process_class = FunctionProcess.build(function, node_class=node_class)

def run_get_node(*args, **kwargs) -> Tuple[Optional[Dict[str, Any]], 'ProcessNode']:
def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, 'ProcessNode']:
"""
Run the FunctionProcess with the supplied inputs in a local runner.
Expand Down Expand Up @@ -159,7 +160,7 @@ def kill_process(_num, _frame):

return result, process.node

def run_get_pk(*args, **kwargs) -> Tuple[Optional[Dict[str, Any]], int]:
def run_get_pk(*args, **kwargs) -> tuple[dict[str, t.Any] | None, int]:
"""Recreate the `run_get_pk` utility launcher.
:param args: input arguments to construct the FunctionProcess
Expand All @@ -185,15 +186,15 @@ def decorated_function(*args, **kwargs):
decorated_function.recreate_from = process_class.recreate_from # type: ignore[attr-defined]
decorated_function.spec = process_class.spec # type: ignore[attr-defined]

return decorated_function
return decorated_function # type: ignore[return-value]

return decorator


class FunctionProcess(Process):
"""Function process class used for turning functions into a Process"""

_func_args: Sequence[str] = ()
_func_args: t.Sequence[str] = ()
_varargs: str | None = None

@staticmethod
Expand All @@ -205,7 +206,7 @@ def _func(*_args, **_kwargs) -> dict:
return {}

@staticmethod
def build(func: Callable[..., Any], node_class: Type['ProcessNode']) -> Type['FunctionProcess']:
def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['FunctionProcess']:
"""
Build a Process from the given function.
Expand Down Expand Up @@ -274,7 +275,7 @@ def _define(cls, spec): # pylint: disable=unused-argument
):
indirect_default = lambda value=default: to_aiida_type(value)
else:
indirect_default = default # type: ignore[assignment]
indirect_default = default

spec.input(parameter.name, valid_type=valid_type, default=indirect_default, serializer=to_aiida_type)

Expand Down Expand Up @@ -306,7 +307,7 @@ def _define(cls, spec): # pylint: disable=unused-argument
)

@classmethod
def validate_inputs(cls, *args: Any, **kwargs: Any) -> None: # pylint: disable=unused-argument
def validate_inputs(cls, *args: t.Any, **kwargs: t.Any) -> None: # pylint: disable=unused-argument
"""
Validate the positional and keyword arguments passed in the function call.
Expand All @@ -327,7 +328,7 @@ def validate_inputs(cls, *args: Any, **kwargs: Any) -> None: # pylint: disable=
raise TypeError(f'{name}() takes {nparameters} positional arguments but {nargs} were given')

@classmethod
def create_inputs(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
def create_inputs(cls, *args: t.Any, **kwargs: t.Any) -> dict[str, t.Any]:
"""Create the input args for the FunctionProcess."""
cls.validate_inputs(*args, **kwargs)

Expand All @@ -339,7 +340,7 @@ def create_inputs(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
return ins

@classmethod
def args_to_dict(cls, *args: Any) -> Dict[str, Any]:
def args_to_dict(cls, *args: t.Any) -> dict[str, t.Any]:
"""
Create an input dictionary (of form label -> value) from supplied args.
Expand Down Expand Up @@ -388,7 +389,7 @@ def __init__(self, *args, **kwargs) -> None:
super().__init__(enable_persistence=False, *args, **kwargs) # type: ignore

@property
def process_class(self) -> Callable[..., Any]:
def process_class(self) -> t.Callable[..., t.Any]:
"""
Return the class that represents this Process, for the FunctionProcess this is the function itself.
Expand All @@ -401,7 +402,7 @@ class that really represents what was being executed.
"""
return self._func

def execute(self) -> Optional[Dict[str, Any]]:
def execute(self) -> dict[str, t.Any] | None:
"""Execute the process."""
result = super().execute()

Expand All @@ -418,7 +419,7 @@ def _setup_db_record(self) -> None:
self.node.store_source_info(self._func)

@override
def run(self) -> Optional['ExitCode']:
def run(self) -> 'ExitCode' | None:
"""Run the process."""
from .exit_code import ExitCode

Expand All @@ -427,7 +428,7 @@ def run(self) -> Optional['ExitCode']:
# been overridden by the engine to `Running` so we cannot check that, but if the `exit_status` is anything other
# than `None`, it should mean this node was taken from the cache, so the process should not be rerun.
if self.node.exit_status is not None:
return self.node.exit_status
return ExitCode(self.node.exit_status, self.node.exit_message)

# Split the inputs into positional and keyword arguments
args = [None] * len(self._func_args)
Expand Down

0 comments on commit ec8cb73

Please sign in to comment.