From 5a5a699cc324ce532d199e8bbfe1d1a8cfa67f24 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 23 Aug 2024 13:18:45 +0200 Subject: [PATCH] Add typechecking workflows and add type annotations --- .github/workflows/ci.yml | 22 ++++++++ .pre-commit-config.yaml | 13 +++++ README.md | 2 +- environment.yml | 8 ++- parcels/_typing.py | 46 +++++++++++++++++ parcels/compilation/codegenerator.py | 6 +-- parcels/interaction/interactionkernel.py | 2 +- parcels/tools/converters.py | 8 +-- parcels/tools/global_statics.py | 2 +- parcels/tools/interpolation_utils.py | 66 ++++++++++++++++-------- parcels/tools/loggers.py | 4 +- parcels/tools/timer.py | 2 +- pyproject.toml | 22 ++++++++ 13 files changed, 168 insertions(+), 35 deletions(-) create mode 100644 parcels/_typing.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6c0a62e3a..41b91bfe8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -81,3 +81,25 @@ jobs: with: name: Integration test report path: ${{ matrix.os }}_integration_test_report.html + typechecking: + name: mypy + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup Conda and parcels + uses: ./.github/actions/install-parcels + with: + environment-file: environment.yml + - run: conda install lxml # dep for report generation + - name: Typechecking + run: | + mypy --install-types --non-interactive parcels --cobertura-xml-report mypy_report + - name: Upload mypy coverage to Codecov + uses: codecov/codecov-action@v3.1.1 + if: ${{ always() }} # Upload even on error of mypy + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: mypy_report/cobertura.xml + flags: mypy + fail_ci_if_error: false diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ccc8a6e5f..eeae58c68 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,3 +24,16 @@ repos: rev: v0.4.0 hooks: - id: biome-format + + # Ruff doesn't have full coverage of pydoclint https://github.com/astral-sh/ruff/issues/12434 + - repo: https://github.com/PyCQA/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + name: pydoclint + files: 'none' + # files: parcels/fieldset.py # put here instead of in config file due to https://github.com/pre-commit/pre-commit-hooks/issues/112#issuecomment-215613842 + args: + - --select=DOC103 # TODO: Expand coverage to other codes + additional_dependencies: + - pydoclint[flake8] diff --git a/README.md b/README.md index 10e6ae196..511f8e5fc 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ ## Parcels [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/OceanParcels/parcels/master?labpath=docs%2Fexamples%2Fparcels_tutorial.ipynb) -[![unit-tests](https://github.com/OceanParcels/parcels/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/OceanParcels/parcels/actions/workflows/unit-tests.yml) +[![unit-tests](https://github.com/OceanParcels/parcels/actions/workflows/ci.yml/badge.svg)](https://github.com/OceanParcels/parcels/actions/workflows/ci.yml) [![codecov](https://codecov.io/gh/OceanParcels/parcels/branch/master/graph/badge.svg)](https://codecov.io/gh/OceanParcels/parcels) [![Anaconda-release](https://anaconda.org/conda-forge/parcels/badges/version.svg)](https://anaconda.org/conda-forge/parcels/) [![Anaconda-date](https://anaconda.org/conda-forge/parcels/badges/latest_release_date.svg)](https://anaconda.org/conda-forge/parcels/) diff --git a/environment.yml b/environment.yml index e244acca1..d096b3c3a 100644 --- a/environment.yml +++ b/environment.yml @@ -33,10 +33,14 @@ dependencies: - pytest-html - coverage + # Typing + - mypy + - types-tqdm + - types-psutil + + # Linting - - flake8>=2.1.0 - pre_commit - - pydocstyle # Docs - ipython diff --git a/parcels/_typing.py b/parcels/_typing.py new file mode 100644 index 000000000..637393a07 --- /dev/null +++ b/parcels/_typing.py @@ -0,0 +1,46 @@ +""" +Typing support for Parcels. + +This module contains type aliases used throughout Parcels as well as functions that are +used for runtime parameter validation (to ensure users are only using the right params). + +""" + +import ast +import datetime +import os +from typing import Any, Callable, Literal, get_args + + +class ParcelsAST(ast.AST): + ccode: str + + +# InterpMethod = InterpMethodOption | dict[str, InterpMethodOption] # (can also be a dict, search for `if type(interp_method) is dict`) +# InterpMethodOption = Literal[ +# "nearest", +# "freeslip", +# "partialslip", +# "bgrid_velocity", +# "bgrid_w_velocity", +# "cgrid_velocity", +# "linear_invdist_land_tracer", +# "nearest", +# "cgrid_tracer", +# ] # mostly corresponds with `interp_method` # TODO: This should be narrowed. Unlikely applies to every context +PathLike = str | os.PathLike +Mesh = Literal["spherical", "flat"] # mostly corresponds with `mesh` +VectorType = Literal["3D", "2D"] | None # mostly corresponds with `vector_type` +ChunkMode = Literal["auto", "specific", "failsafe"] # mostly corresponds with `chunk_mode` +GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo"] # mostly corresponds with `grid_indexing_type` +UpdateStatus = Literal["not_updated", "first_updated", "updated"] # mostly corresponds with `update_status` +TimePeriodic = float | datetime.timedelta | Literal[False] # mostly corresponds with `update_status` + +KernelFunction = Callable[..., None] + + +def ensure_is_literal_value(value: Any, literal: Any) -> None: + """Ensures that a value is a valid option for the provided Literal type annotation.""" + valid_options = get_args(literal) + if value not in valid_options: + raise ValueError(f"{value!r} is not a valid option. Valid options are {valid_options}") diff --git a/parcels/compilation/codegenerator.py b/parcels/compilation/codegenerator.py index 2569dcc6c..921b422cb 100644 --- a/parcels/compilation/codegenerator.py +++ b/parcels/compilation/codegenerator.py @@ -410,7 +410,7 @@ class KernelGenerator(ABC, ast.NodeVisitor): # Intrinsic variables that appear as function arguments kernel_vars = ["particle", "fieldset", "time", "output_time", "tol"] - array_vars = [] + array_vars: list[str] = [] def __init__(self, fieldset=None, ptype=JITParticle): self.fieldset = fieldset @@ -419,7 +419,7 @@ def __init__(self, fieldset=None, ptype=JITParticle): self.vector_field_args = collections.OrderedDict() self.const_args = collections.OrderedDict() - def generate(self, py_ast, funcvars): + def generate(self, py_ast, funcvars: list[str]): # Replace occurrences of intrinsic objects in Python AST transformer = IntrinsicTransformer(self.fieldset, self.ptype) py_ast = transformer.visit(py_ast) @@ -434,7 +434,7 @@ def generate(self, py_ast, funcvars): # Insert variable declarations for non-intrinsic variables # Make sure that repeated variables are not declared more than # once. If variables occur in multiple Kernels, give a warning - used_vars = [] + used_vars: list[str] = [] funcvars_copy = copy(funcvars) # editing a list while looping over it is dangerous for kvar in funcvars: if kvar in used_vars + ["particle_dlon", "particle_dlat", "particle_ddepth"]: diff --git a/parcels/interaction/interactionkernel.py b/parcels/interaction/interactionkernel.py index db4c7f897..c07af94d6 100644 --- a/parcels/interaction/interactionkernel.py +++ b/parcels/interaction/interactionkernel.py @@ -36,7 +36,7 @@ def __init__( py_ast=None, funcvars=None, c_include="", - delete_cfiles=True, + delete_cfiles: bool = True, ): if MPI is not None and MPI.COMM_WORLD.Get_size() > 1: raise NotImplementedError( diff --git a/parcels/tools/converters.py b/parcels/tools/converters.py index c8ccf8003..17b7a5abf 100644 --- a/parcels/tools/converters.py +++ b/parcels/tools/converters.py @@ -2,10 +2,12 @@ import inspect from datetime import timedelta from math import cos, pi +from typing import Any import cftime import numpy as np import xarray as xr +from numpy.typing import ArrayLike, NDArray __all__ = [ "UnitConverter", @@ -20,7 +22,7 @@ ] -def convert_to_flat_array(var): +def convert_to_flat_array(var: list[float] | float | int | NDArray[Any] | ArrayLike) -> NDArray[Any]: """Convert lists and single integers/floats to one-dimensional numpy arrays Parameters @@ -167,8 +169,8 @@ def __le__(self, other): class UnitConverter: """Interface class for spatial unit conversion during field sampling that performs no conversion.""" - source_unit = None - target_unit = None + source_unit: str | None = None + target_unit: str | None = None def to_target(self, value, x, y, z): return value diff --git a/parcels/tools/global_statics.py b/parcels/tools/global_statics.py index 0e97bac0d..896d3195f 100644 --- a/parcels/tools/global_statics.py +++ b/parcels/tools/global_statics.py @@ -8,7 +8,7 @@ from os import getuid except: # Windows does not have getuid(), so define to simply return 'tmp' - def getuid(): + def getuid(): # type: ignore return "tmp" diff --git a/parcels/tools/interpolation_utils.py b/parcels/tools/interpolation_utils.py index 27ab24c7f..273dbbec8 100644 --- a/parcels/tools/interpolation_utils.py +++ b/parcels/tools/interpolation_utils.py @@ -1,17 +1,19 @@ +from typing import Callable, Literal + import numpy as np -__all__ = [] +from parcels._typing import Mesh +__all__ = [] # type: ignore -# fmt: off -def phi1D_lin(xsi): - phi = [1-xsi, - xsi] +def phi1D_lin(xsi: float) -> list[float]: + phi = [1 - xsi, xsi] return phi -def phi1D_quad(xsi): +# fmt: off +def phi1D_quad(xsi: float) -> list[float]: phi = [2*xsi**2-3*xsi+1, -4*xsi**2+4*xsi, 2*xsi**2-xsi] @@ -19,7 +21,8 @@ def phi1D_quad(xsi): return phi -def phi2D_lin(xsi, eta): + +def phi2D_lin(xsi: float, eta: float) -> list[float]: phi = [(1-xsi) * (1-eta), xsi * (1-eta), xsi * eta , @@ -28,7 +31,7 @@ def phi2D_lin(xsi, eta): return phi -def phi3D_lin(xsi, eta, zet): +def phi3D_lin(xsi: float, eta: float, zet: float) -> list[float]: phi = [(1-xsi) * (1-eta) * (1-zet), xsi * (1-eta) * (1-zet), xsi * eta * (1-zet), @@ -41,7 +44,7 @@ def phi3D_lin(xsi, eta, zet): return phi -def dphidxsi3D_lin(xsi, eta, zet): +def dphidxsi3D_lin(xsi: float, eta: float, zet: float) -> tuple[list[float], list[float], list[float]]: dphidxsi = [ - (1-eta) * (1-zet), (1-eta) * (1-zet), ( eta) * (1-zet), @@ -70,7 +73,9 @@ def dphidxsi3D_lin(xsi, eta, zet): return dphidxsi, dphideta, dphidzet -def dxdxsi3D_lin(hexa_x, hexa_y, hexa_z, xsi, eta, zet, mesh): +def dxdxsi3D_lin( + hexa_x: list[float], hexa_y: list[float], hexa_z: list[float], xsi: float, eta: float, zet: float, mesh: Mesh +) -> tuple[float, float, float, float, float, float, float, float, float]: dphidxsi, dphideta, dphidzet = dphidxsi3D_lin(xsi, eta, zet) if mesh == 'spherical': @@ -99,16 +104,29 @@ def dxdxsi3D_lin(hexa_x, hexa_y, hexa_z, xsi, eta, zet, mesh): return dxdxsi, dxdeta, dxdzet, dydxsi, dydeta, dydzet, dzdxsi, dzdeta, dzdzet -def jacobian3D_lin(hexa_x, hexa_y, hexa_z, xsi, eta, zet, mesh): +def jacobian3D_lin( + hexa_x: list[float], hexa_y: list[float], hexa_z: list[float], xsi: float, eta: float, zet: float, mesh: Mesh +) -> float: dxdxsi, dxdeta, dxdzet, dydxsi, dydeta, dydzet, dzdxsi, dzdeta, dzdzet = dxdxsi3D_lin(hexa_x, hexa_y, hexa_z, xsi, eta, zet, mesh) - jac = dxdxsi * (dydeta*dzdzet - dzdeta*dydzet)\ - - dxdeta * (dydxsi*dzdzet - dzdxsi*dydzet)\ - + dxdzet * (dydxsi*dzdeta - dzdxsi*dydeta) + jac = ( + dxdxsi * (dydeta * dzdzet - dzdeta * dydzet) + - dxdeta * (dydxsi * dzdzet - dzdxsi * dydzet) + + dxdzet * (dydxsi * dzdeta - dzdxsi * dydeta) + ) return jac -def jacobian3D_lin_face(hexa_x, hexa_y, hexa_z, xsi, eta, zet, orientation, mesh): +def jacobian3D_lin_face( + hexa_x: list[float], + hexa_y: list[float], + hexa_z: list[float], + xsi: float, + eta: float, + zet: float, + orientation: Literal["zonal", "meridional", "vertical"], + mesh: Mesh, +) -> float: dxdxsi, dxdeta, dxdzet, dydxsi, dydeta, dydzet, dzdxsi, dzdeta, dzdzet = dxdxsi3D_lin(hexa_x, hexa_y, hexa_z, xsi, eta, zet, mesh) if orientation == 'zonal': @@ -128,7 +146,7 @@ def jacobian3D_lin_face(hexa_x, hexa_y, hexa_z, xsi, eta, zet, orientation, mesh return jac -def dphidxsi2D_lin(xsi, eta): +def dphidxsi2D_lin(xsi: float, eta: float) -> tuple[list[float], list[float]]: dphidxsi = [-(1-eta), 1-eta, eta, @@ -141,7 +159,12 @@ def dphidxsi2D_lin(xsi, eta): return dphidxsi, dphideta -def dxdxsi2D_lin(quad_x, quad_y, xsi, eta,): +def dxdxsi2D_lin( + quad_x, + quad_y, + xsi: float, + eta: float, +): dphidxsi, dphideta = dphidxsi2D_lin(xsi, eta) dxdxsi = np.dot(quad_x, dphidxsi) @@ -152,20 +175,21 @@ def dxdxsi2D_lin(quad_x, quad_y, xsi, eta,): return dxdxsi, dxdeta, dydxsi, dydeta -def jacobian2D_lin(quad_x, quad_y, xsi, eta): +def jacobian2D_lin(quad_x, quad_y, xsi: float, eta: float): dxdxsi, dxdeta, dydxsi, dydeta = dxdxsi2D_lin(quad_x, quad_y, xsi, eta) - jac = dxdxsi*dydeta - dxdeta*dydxsi + jac = dxdxsi * dydeta - dxdeta * dydxsi return jac def length2d_lin_edge(quad_x, quad_y, ids): xe = [quad_x[ids[0]], quad_x[ids[1]]] ye = [quad_y[ids[0]], quad_y[ids[1]]] - return np.sqrt((xe[1]-xe[0])**2+(ye[1]-ye[0])**2) + return np.sqrt((xe[1] - xe[0]) ** 2 + (ye[1] - ye[0]) ** 2) -def interpolate(phi, f, xsi): +def interpolate(phi: Callable[[float], list[float]], f: list[float], xsi: float) -> float: return np.dot(phi(xsi), f) + # fmt: on diff --git a/parcels/tools/loggers.py b/parcels/tools/loggers.py index b21dee81a..335bc3a8d 100644 --- a/parcels/tools/loggers.py +++ b/parcels/tools/loggers.py @@ -40,10 +40,10 @@ def info_once(self, message, *args, **kws): logger.addHandler(handler) logging.addLevelName(warning_once_level, "WARNING") -logging.Logger.warning_once = warning_once +logging.Logger.warning_once = warning_once # type: ignore logging.addLevelName(info_once_level, "INFO") -logging.Logger.info_once = info_once +logging.Logger.info_once = info_once # type: ignore dup_filter = DuplicateFilter() logger.addFilter(dup_filter) diff --git a/parcels/tools/timer.py b/parcels/tools/timer.py index 02daf75ab..6aa7701e3 100644 --- a/parcels/tools/timer.py +++ b/parcels/tools/timer.py @@ -6,7 +6,7 @@ except ModuleNotFoundError: MPI = None -__all__ = [] +__all__ = [] # type: ignore class Timer: diff --git a/pyproject.toml b/pyproject.toml index d42401fbe..613ff9978 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,3 +121,25 @@ ignore = [ [tool.ruff.lint.pydocstyle] convention = "numpy" + +[tool.mypy] +files = [ + "parcels/compilation/codegenerator.py", + "parcels/_typing.py", + "parcels/tools/*.py", + "parcels/grid.py", +] + +[[tool.mypy.overrides]] +module = [ + "parcels._version_setup", + "mpi4py", + "scipy.spatial", + "sklearn.cluster", + "zarr", + "cftime", + "pykdtree.kdtree", + "netCDF4", + "cgen" +] +ignore_missing_imports = true