Skip to content

Commit

Permalink
Merge #1129
Browse files Browse the repository at this point in the history
1129: Implement Dask collection interface r=hgrecco a=rpmanser

This pull request implements the [Dask collection interface](https://docs.dask.org/en/latest/custom-collections.html) for the Quantity class and adds convenience methods `compute()`, `persist()`, and `visualize()`. As is, only the convenience methods are wrapped to check that the magnitude of the Quantity is a dask array, which should cover current use cases. Ping @jthielen since we've been discussing this.

I have not tested `persist()` on a HPC cluster, which will have different behavior than when it is called on a single machine (e.g., desktop). Is there a way to integrate tests for distributed cases? If not, I can test this independently if needed.

- [X] Closes #883 
- [X] Executed ``black -t py36 . && isort -rc . && flake8`` with no errors
- [X] The change is fully covered by automated unit tests
- [X] Documented in docs/ as appropriate
- [X] Added an entry to the CHANGES file

Co-authored-by: Russell Manser <russell.p.manser@ttu.edu>
  • Loading branch information
bors[bot] and rpmanser authored Jul 8, 2020
2 parents 4ab68ba + e529f34 commit f1ddf49
Show file tree
Hide file tree
Showing 8 changed files with 346 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Pint Changelog
0.15 (unreleased)
-----------------

- Nothing changed yet.
- Implement Dask collection interface to support Pint Quantity wrapped Dask arrays.


0.14 (2020-07-01)
Expand Down
32 changes: 31 additions & 1 deletion docs/numpy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,36 @@
"print(repr(m * ureg.m))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Pint Quantity wrapping Dask Array**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import dask.array as da\n",
"\n",
"d = da.arange(500, chunks=50)\n",
"\n",
"# Must create using Quantity class, otherwise Dask will wrap Pint Quantity\n",
"q = ureg.Quantity(d, ureg.kelvin)\n",
"\n",
"print(repr(q))\n",
"print()\n",
"\n",
"# DO NOT create using multiplication on the right until\n",
"# https://github.com/dask/dask/issues/4583 is resolved, as\n",
"# unexpected behavior may result\n",
"print(repr(d * ureg.kelvin))\n",
"print(repr(ureg.kelvin * d))"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -465,7 +495,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.8.2"
}
},
"nbformat": 4,
Expand Down
11 changes: 10 additions & 1 deletion pint/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class BehaviorChangeWarning(UserWarning):

try:
import numpy as np
from numpy import ndarray, datetime64 as np_datetime64
from numpy import datetime64 as np_datetime64
from numpy import ndarray

HAS_NUMPY = True
NUMPY_VER = np.__version__
Expand Down Expand Up @@ -157,6 +158,14 @@ def _to_magnitude(value, force_ndarray=False, force_ndarray_like=False):
except ImportError:
pass

# Optional Dask support: import dask's collection entry points when dask is
# installed; otherwise bind ``None`` sentinels so downstream code can
# feature-detect (e.g. ``if dask_array is not None``) without a hard dependency.
try:
    from dask import array as dask_array
    from dask.base import compute, persist, visualize

except ImportError:
    compute, persist, visualize = None, None, None
    dask_array = None


def is_upcast_type(other) -> bool:
"""Check if the type object is a upcast type using preset list.
Expand Down
1 change: 1 addition & 0 deletions pint/pint-convert
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ ureg.default_system = args.system

if args.unc:
import uncertainties

# Measured constants subject to correlation
# R_i: Rydberg constant
# g_e: Electron g factor
Expand Down
100 changes: 100 additions & 0 deletions pint/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@
NUMPY_VER,
_to_magnitude,
babel_parse,
compute,
dask_array,
eq,
is_duck_array_type,
is_upcast_type,
ndarray,
np,
persist,
visualize,
zero_or_nan,
)
from .definitions import UnitDefinition
Expand Down Expand Up @@ -125,6 +129,20 @@ def wrapper(func):
return wrapper


def check_dask_array(f):
    """Decorator restricting a method to Quantities wrapping a dask array.

    The decorated method runs only when ``self._magnitude`` is a
    ``dask.array.Array``; otherwise an ``AttributeError`` is raised so the
    failure reads as "this object does not have that capability".
    """

    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
        # dask is an optional dependency; compat binds ``dask_array`` to None
        # when it is missing. Without this guard the isinstance() check below
        # would fail with a confusing "'NoneType' object has no attribute
        # 'Array'" error instead of a clear explanation.
        if dask_array is None:
            raise AttributeError(
                "Method {} requires the optional dask dependency, "
                "which is not installed".format(f.__name__)
            )
        if isinstance(self._magnitude, dask_array.Array):
            return f(self, *args, **kwargs)
        else:
            msg = "Method {} only implemented for objects of {}, not {}".format(
                f.__name__, dask_array.Array, self._magnitude.__class__
            )
            raise AttributeError(msg)

    return wrapper


@contextlib.contextmanager
def printoptions(*args, **kwargs):
"""Numpy printoptions context manager released with version 1.15.0
Expand Down Expand Up @@ -1900,6 +1918,88 @@ def _ok_for_muldiv(self, no_offset_units=None):
def to_timedelta(self):
return datetime.timedelta(microseconds=self.to("microseconds").magnitude)

# Dask.array.Array ducking
def __dask_graph__(self):
    """Return the dask task graph of the wrapped magnitude.

    Returns ``None`` when the magnitude is not a dask array.
    """
    magnitude = self._magnitude
    if not isinstance(magnitude, dask_array.Array):
        return None
    return magnitude.__dask_graph__()

def __dask_keys__(self):
    """Return the output keys of the wrapped magnitude's dask graph."""
    magnitude = self._magnitude
    return magnitude.__dask_keys__()

@property
def __dask_optimize__(self):
    """Graph optimization function, delegated to ``dask.array.Array``."""
    return dask_array.Array.__dask_optimize__

@property
def __dask_scheduler__(self):
    """Default scheduler, delegated to ``dask.array.Array``."""
    return dask_array.Array.__dask_scheduler__

def __dask_postcompute__(self):
    """Finalizer for ``dask.compute``: rebuild a Quantity carrying our units."""
    finalize, finalize_args = self._magnitude.__dask_postcompute__()
    return self._dask_finalize, (finalize, finalize_args, self.units)

def __dask_postpersist__(self):
    """Finalizer for ``dask.persist``: rebuild a Quantity carrying our units."""
    rebuild, rebuild_args = self._magnitude.__dask_postpersist__()
    return self._dask_finalize, (rebuild, rebuild_args, self.units)

@staticmethod
def _dask_finalize(results, func, args, units):
    """Apply the magnitude's finalizer to *results* and re-attach *units*."""
    magnitude = func(results, *args)
    return Quantity(magnitude, units)

@check_dask_array
def compute(self, **kwargs):
    """Compute the Dask array wrapped by pint.Quantity.

    Parameters
    ----------
    **kwargs : dict
        Any keyword arguments to pass to ``dask.compute``.

    Returns
    -------
    pint.Quantity
        A pint.Quantity wrapped numpy array.
    """
    # ``compute`` here is dask.base.compute (module scope), not this method.
    (computed,) = compute(self, **kwargs)
    return computed

@check_dask_array
def persist(self, **kwargs):
    """Persist the Dask Array wrapped by pint.Quantity.

    Parameters
    ----------
    **kwargs : dict
        Any keyword arguments to pass to ``dask.persist``.

    Returns
    -------
    pint.Quantity
        A pint.Quantity wrapped Dask array.
    """
    # ``persist`` here is dask.base.persist (module scope), not this method.
    (persisted,) = persist(self, **kwargs)
    return persisted

@check_dask_array
def visualize(self, **kwargs):
    """Produce a visual representation of the Dask graph.

    The graphviz library is required.

    Parameters
    ----------
    **kwargs : dict
        Any keyword arguments to pass to the ``dask.base.visualize`` function.

    Returns
    -------
    None
        The rendering produced by ``dask.base.visualize`` is intentionally
        not returned.
    """
    visualize(self, **kwargs)


_Quantity = Quantity

Expand Down
9 changes: 6 additions & 3 deletions pint/testsuite/test_compat_downcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# Conditionally import NumPy and any upcast type libraries
np = pytest.importorskip("numpy", reason="NumPy is not available")
sparse = pytest.importorskip("sparse", reason="sparse is not available")
da = pytest.importorskip("dask.array", reason="Dask is not available")

# Set up unit registry and sample
ureg = UnitRegistry(force_ndarray_like=True)
Expand All @@ -16,7 +17,7 @@ def identity(x):
return x


@pytest.fixture(params=["sparse", "masked_array"])
@pytest.fixture(params=["sparse", "masked_array", "dask_array"])
def array(request):
"""Generate 5x5 arrays of given type for tests."""
if request.param == "sparse":
Expand All @@ -30,6 +31,8 @@ def array(request):
np.arange(25, dtype=np.float).reshape((5, 5)),
mask=np.logical_not(np.triu(np.ones((5, 5)))),
)
elif request.param == "dask_array":
return da.arange(25, chunks=5, dtype=float).reshape((5, 5))


@pytest.mark.parametrize(
Expand Down Expand Up @@ -57,8 +60,8 @@ def array(request):
pytest.param(np.sum, np.sum, identity, id="sum ufunc"),
pytest.param(np.sqrt, np.sqrt, lambda u: u ** 0.5, id="sqrt ufunc"),
pytest.param(
lambda x: np.reshape(x, 25),
lambda x: np.reshape(x, 25),
lambda x: np.reshape(x, (25,)),
lambda x: np.reshape(x, (25,)),
identity,
id="reshape function",
),
Expand Down
Loading

0 comments on commit f1ddf49

Please sign in to comment.