dags provides tools to combine several interrelated functions into one function. The order in which the functions are called is determined by a topological sort of a directed acyclic graph (DAG) that is constructed from the function signatures. You can specify which of the function results the combined function returns.
dags is a tiny library; all the hard work is done by the great NetworkX.
To understand what dags does, let's look at a simple example: a few functions that each do a small calculation.
def f(x, y):
    return x**2 + y**2

def g(y, z):
    return 0.5 * y * z

def h(f, g):
    return g / f
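An argument whose name matches another function's name marks a dependency. So the three functions above already define a small DAG: `f` and `g` feed into `h`, and `x`, `y` and `z` are plain inputs. The sketch below shows one way such a graph could be read off the signatures and put in call order. It is an illustration under stated assumptions, not dags's own code. dags relies on NetworkX, but this sketch substitutes the standard library's `graphlib.TopologicalSorter` (Python 3.9+) to stay dependency-free. `build_dependencies` is a hypothetical helper.

```python
import inspect
from graphlib import TopologicalSorter


def f(x, y):
    return x**2 + y**2


def g(y, z):
    return 0.5 * y * z


def h(f, g):
    return g / f


def build_dependencies(functions):
    """Hypothetical helper: map each function name to its parameter names."""
    return {
        func.__name__: list(inspect.signature(func).parameters)
        for func in functions
    }


deps = build_dependencies([h, f, g])
# deps == {"h": ["f", "g"], "f": ["x", "y"], "g": ["y", "z"]}

# TopologicalSorter takes a mapping of node -> predecessors and yields
# each node only after all of its predecessors.
order = list(TopologicalSorter(deps).static_order())
# The inputs x, y, z come before f and g, and h comes last.
print(order)
```

The relative order of independent nodes (for example `f` versus `g`) is not fixed by the graph. Only the constraint that every function runs after its inputs matters.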
Assume that we are interested in a function that calculates h, given x, y and z.
We could hardcode this function as:
def hardcoded_combined(x, y, z):
    _f = f(x, y)
    _g = g(y, z)
    return h(_f, _g)
hardcoded_combined(x=1, y=2, z=3)
0.6
Instead, we can use dags to construct the same function:
from dags import concatenate_functions
combined = concatenate_functions([h, f, g], targets="h")
combined(x=1, y=2, z=3)
0.6
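To make the idea concrete, here is a toy version of such a combiner. It is a minimal sketch and not how dags implements `concatenate_functions`. `toy_concatenate` is a hypothetical name, and the sketch uses `graphlib` in place of NetworkX. It computes the call order once, when the combined function is created. At call time it simply walks that order and passes each function the values its signature asks for. Unlike a real implementation, it evaluates every function rather than only the target's ancestors, and it does not validate missing inputs.

```python
import inspect
from graphlib import TopologicalSorter


def f(x, y):
    return x**2 + y**2


def g(y, z):
    return 0.5 * y * z


def h(f, g):
    return g / f


def toy_concatenate(functions, target):
    """Hypothetical sketch: combine functions into one callable returning `target`."""
    by_name = {func.__name__: func for func in functions}
    deps = {
        name: list(inspect.signature(func).parameters)
        for name, func in by_name.items()
    }
    # Determine the call order once, at creation time; keep only function nodes.
    order = [
        node for node in TopologicalSorter(deps).static_order() if node in by_name
    ]

    def combined(**inputs):
        results = dict(inputs)
        for name in order:
            # Call each function with exactly the arguments its signature names.
            kwargs = {arg: results[arg] for arg in deps[name]}
            results[name] = by_name[name](**kwargs)
        return results[target]

    return combined


combined = toy_concatenate([h, f, g], target="h")
print(combined(x=1, y=2, z=3))  # 0.6
```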
More examples can be found in the documentation.

Notable features:
- The DAG is built once, when the combined function is created, so it adds little overhead when the combined function is called.
- If all individual functions are JAX-compatible, the combined function is JAX-compatible as well.
- When jitted or vmapped with JAX, we have not seen any performance loss compared to hardcoding the combined function.
- When there is more than one target, you can choose whether the results are returned as a tuple, list, or dict, or pass in an aggregator that combines the multiple outputs.
- Since the relationships are discovered from function signatures, dags provides decorators to rename arguments.
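Because dependencies are matched purely by name, a function written with generic parameter names has to be renamed before it can join the DAG. dags ships its own decorators for this. The sketch below does not reproduce their API. Instead it shows the general technique with a hypothetical `rename_args` decorator: wrap the function, translate the new keyword names back to the old ones at call time, and overwrite `__signature__` so that `inspect.signature` reports the new names.

```python
import functools
import inspect


def rename_args(mapping):
    """Hypothetical decorator: expose parameters under new names (old -> new)."""

    def decorator(func):
        old_sig = inspect.signature(func)
        new_params = [
            param.replace(name=mapping.get(param.name, param.name))
            for param in old_sig.parameters.values()
        ]
        # Reverse lookup: new name -> original name.
        reverse = {new: old for old, new in mapping.items()}

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Translate renamed keyword arguments back to the original names.
            kwargs = {reverse.get(key, key): value for key, value in kwargs.items()}
            return func(*args, **kwargs)

        # inspect.signature checks __signature__ before following __wrapped__.
        wrapper.__signature__ = old_sig.replace(parameters=new_params)
        return wrapper

    return decorator


# A generic function whose parameter names do not match the node names f and g.
@rename_args({"numerator": "g", "denominator": "f"})
def ratio(numerator, denominator):
    return numerator / denominator


print(inspect.signature(ratio))  # (g, f)
print(ratio(g=3.0, f=5))  # 0.6
```

After renaming, `ratio` advertises the same parameters as `h` in the example above, so a signature-based combiner would wire it to `f` and `g` automatically.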
dags is available on PyPI and Anaconda.org. Install it with:
$ pip install dags
# or
$ conda install -c conda-forge dags
The documentation is hosted on Read the Docs.