Multiple outputs (#419)
* Multiple outputs

* Get general_blockwise core op working for multiple outputs

* Handle fusion of multiple outputs

Test for child fusion (where a multiple-outputs op is fused with its two children)

test_fuse_multiple_outputs_diamond

sibling fusion test
* Mem utilization test for multiple outputs

* Allow multiple output functions to just return a tuple

* Fix for Zarr v3
tomwhite authored Sep 16, 2024
1 parent bafa0b3 commit 3dbf6d6
Showing 10 changed files with 448 additions and 118 deletions.
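Taken together, the changes below let a single general_blockwise op write more than one output array: the singular shape/dtype/chunks keywords become the list-valued shapes/dtypes/chunkss (one entry per output), the block function may simply return a tuple, and the op returns a tuple of arrays. The sketch below is read off the diff rather than documented API: the two-output function, the key mapping, and passing target_stores=[None, None] to select the multiple-outputs branch are illustrative assumptions.

import numpy as np

import cubed
import cubed.array_api as xp
from cubed.core.ops import general_blockwise


def block_min_and_max(block):
    # A multiple-output function can now just return a tuple,
    # one element per output array.
    return (
        np.min(block, axis=1, keepdims=True),
        np.max(block, axis=1, keepdims=True),
    )


spec = cubed.Spec(allowed_mem="100MB")
x = xp.asarray(np.arange(16).reshape((4, 4)), chunks=(2, 4), spec=spec)


def key_function(out_key):
    # Block (i, j) of every output is computed from input block (i, j).
    return ((x.name,) + out_key[1:],)


mins, maxs = general_blockwise(
    block_min_and_max,
    key_function,
    x,
    shapes=[(4, 1), (4, 1)],     # one entry per output
    dtypes=[x.dtype, x.dtype],
    chunkss=[(2, 1), (2, 1)],
    target_stores=[None, None],  # a list here selects the multiple-outputs path
)
print(cubed.compute(mins, maxs))

Existing single-output call sites, such as those in manipulation_functions.py below, migrate mechanically: each singular value is wrapped in a one-element list.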
12 changes: 6 additions & 6 deletions cubed/array_api/manipulation_functions.py
@@ -321,9 +321,9 @@ def key_function(out_key):
        key_function,
        x,
        template,
-        shape=shape,
-        dtype=x.dtype,
-        chunks=outchunks,
+        shapes=[shape],
+        dtypes=[x.dtype],
+        chunkss=[outchunks],
    )


@@ -402,9 +402,9 @@ def key_function(out_key):
        _read_stack_chunk,
        key_function,
        *arrays,
-        shape=shape,
-        dtype=dtype,
-        chunks=chunks,
+        shapes=[shape],
+        dtypes=[dtype],
+        chunkss=[chunks],
        axis=axis,
        fusable=False,
    )
60 changes: 36 additions & 24 deletions cubed/core/ops.py
@@ -5,7 +5,7 @@
from itertools import product
from numbers import Integral, Number
from operator import add
-from typing import TYPE_CHECKING, Any, Sequence, Union
+from typing import TYPE_CHECKING, Any, Sequence, Tuple, Union
from warnings import warn

import ndindex
@@ -333,14 +333,14 @@ def general_blockwise(
    func,
    key_function,
    *arrays,
-    shape,
-    dtype,
-    chunks,
-    target_store=None,
-    target_path=None,
+    shapes,
+    dtypes,
+    chunkss,
+    target_stores=None,
+    target_paths=None,
    extra_func_kwargs=None,
    **kwargs,
-) -> "Array":
+) -> Union["Array", Tuple["Array", ...]]:
    assert len(arrays) > 0

    # replace arrays with zarr arrays
@@ -354,24 +354,33 @@

    num_input_blocks = kwargs.pop("num_input_blocks", None)

-    name = gensym()
    spec = check_array_specs(arrays)
-    if target_store is None:
-        target_store = new_temp_path(name=name, spec=spec)
+
+    if isinstance(target_stores, list):  # multiple outputs
+        name = [gensym() for _ in range(len(target_stores))]
+        target_stores = [
+            ts if ts is not None else new_temp_path(name=n, spec=spec)
+            for n, ts in zip(name, target_stores)
+        ]
+    else:  # single output
+        name = gensym()
+        if target_stores is None:
+            target_stores = [new_temp_path(name=name, spec=spec)]

    op = primitive_general_blockwise(
        func,
        key_function,
        *zargs,
        allowed_mem=spec.allowed_mem,
        reserved_mem=spec.reserved_mem,
        extra_projected_mem=extra_projected_mem,
-        target_store=target_store,
-        target_path=target_path,
+        target_stores=target_stores,
+        target_paths=target_paths,
        storage_options=spec.storage_options,
        compressor=spec.zarr_compressor,
-        shape=shape,
-        dtype=dtype,
-        chunks=chunks,
+        shapes=shapes,
+        dtypes=dtypes,
+        chunkss=chunkss,
        in_names=in_names,
        extra_func_kwargs=extra_func_kwargs,
        num_input_blocks=num_input_blocks,
@@ -387,7 +396,10 @@
    )
    from cubed.array_api import Array

-    return Array(name, op.target_array, spec, plan)
+    if isinstance(op.target_array, list):  # multiple outputs
+        return tuple(Array(n, ta, spec, plan) for n, ta in zip(name, op.target_array))
+    else:  # single output
+        return Array(name, op.target_array, spec, plan)


def elemwise(func, *args: "Array", dtype=None) -> "Array":
@@ -914,9 +926,9 @@ def key_function(out_key):
        _concatenate2,
        key_function,
        x,
-        shape=x.shape,
-        dtype=x.dtype,
-        chunks=target_chunks,
+        shapes=[x.shape],
+        dtypes=[x.dtype],
+        chunkss=[target_chunks],
        extra_projected_mem=0,
        num_input_blocks=(num_input_blocks,),
        axes=axes,
@@ -1229,12 +1241,12 @@ def partial_reduce(
    axis = tuple(ax for ax in split_every.keys())
    combine_sizes = combine_sizes or {}
    combine_sizes = {k: combine_sizes.get(k, 1) for k in axis}
-    chunks = [
+    chunks = tuple(
        (combine_sizes[i],) * math.ceil(len(c) / split_every[i])
        if i in split_every
        else c
        for (i, c) in enumerate(x.chunks)
-    ]
+    )
    shape = tuple(map(sum, chunks))

    def key_function(out_key):
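As a side change, partial_reduce now builds chunks as a tuple rather than a list. A worked instance of the expression above, with the input chunk lengths, split_every, and combine_sizes invented for illustration:

import math

x_chunks = ((2, 2, 2, 2), (3, 3))  # hypothetical chunk lengths per axis
split_every = {0: 4}               # combine 4 chunks at a time along axis 0
combine_sizes = {0: 1}             # each combined chunk has length 1

chunks = tuple(
    (combine_sizes[i],) * math.ceil(len(c) / split_every[i])
    if i in split_every
    else c
    for (i, c) in enumerate(x_chunks)
)
shape = tuple(map(sum, chunks))
print(chunks)  # ((1,), (3, 3))
print(shape)   # (1, 6)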
@@ -1263,9 +1275,9 @@ def key_function(out_key):
        _partial_reduce,
        key_function,
        x,
-        shape=shape,
-        dtype=dtype,
-        chunks=chunks,
+        shapes=[shape],
+        dtypes=[dtype],
+        chunkss=[chunks],
        extra_projected_mem=extra_projected_mem,
        num_input_blocks=(sum(split_every.values()),),
        reduce_func=func,
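The branch added to general_blockwise pairs every output with its own generated name and target store, falling back to a per-output temp path when no store is supplied. A standalone sketch of just that pairing logic, with gensym and new_temp_path replaced by hypothetical stand-ins:

import itertools

_ids = itertools.count()


def gensym(prefix="op"):
    # Stand-in for cubed's gensym(): unique node names.
    return f"{prefix}-{next(_ids):03d}"


def new_temp_path(name):
    # Stand-in for cubed's new_temp_path(): where an intermediate array lives.
    return f"/tmp/cubed/{name}.zarr"


def resolve_targets(target_stores):
    # Mirrors the commit's branching: passing a list selects the
    # multiple-outputs path; each missing store gets its own temp path.
    if isinstance(target_stores, list):  # multiple outputs
        names = [gensym() for _ in range(len(target_stores))]
        stores = [
            ts if ts is not None else new_temp_path(n)
            for n, ts in zip(names, target_stores)
        ]
        return names, stores
    else:  # single output
        name = gensym()
        if target_stores is None:
            target_stores = [new_temp_path(name)]
        return name, target_stores


print(resolve_targets([None, "results/maxes.zarr"]))
# (['op-000', 'op-001'], ['/tmp/cubed/op-000.zarr', 'results/maxes.zarr'])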
21 changes: 19 additions & 2 deletions cubed/core/optimization.py
@@ -31,9 +31,9 @@ def can_fuse(n):
if "primitive_op" not in nodes[op2]:
return False

# if node (op2) does not have exactly one input then don't fuse
# if node (op2) does not have exactly one input and output then don't fuse
# (it could have no inputs or multiple inputs)
if dag.in_degree(op2) != 1:
if dag.in_degree(op2) != 1 or dag.out_degree(op2) != 1:
return False

# if input is one of the arrays being computed then don't fuse
@@ -91,6 +91,12 @@ def predecessors_unordered(dag, name):
        yield pre


+def successors_unordered(dag, name):
+    """Return a node's successors in no particular order, with repeats for multiple edges."""
+    for _, succ in dag.out_edges(name):
+        yield succ
+
+
def predecessor_ops(dag, name):
    """Return an op node's op predecessors in the same order as the input source arrays for the op.
@@ -183,6 +189,17 @@ def can_fuse_predecessors(
        )
        return False

+    # if any predecessor ops have multiple outputs then don't fuse
+    # TODO: implement "child fusion" (where a multiple output op fuses its children)
+    if any(
+        len(list(successors_unordered(dag, pre))) > 1
+        for pre in predecessor_ops(dag, name)
+    ):
+        logger.debug(
+            "can't fuse %s since at least one predecessor has multiple outputs", name
+        )
+        return False
+
    # if node is in never_fuse or always_fuse list then it overrides logic below
    if never_fuse is not None and name in never_fuse:
        logger.debug("can't fuse %s since it is in 'never_fuse'", name)
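The new fusion guard needs only out-degrees, so it can be exercised on a bare networkx graph of the same shape the plan uses. A toy sketch with invented node names, showing that an op with two output arrays is now excluded from fusion:

import networkx as nx


def successors_unordered(dag, name):
    """Return a node's successors in no particular order, with repeats for multiple edges."""
    for _, succ in dag.out_edges(name):
        yield succ


dag = nx.MultiDiGraph()
dag.add_edges_from(
    [
        ("op-a", "array-a1"),  # op-a writes two output arrays ...
        ("op-a", "array-a2"),
        ("op-b", "array-b1"),  # ... op-b writes one
    ]
)

for op in ("op-a", "op-b"):
    multiple_outputs = len(list(successors_unordered(dag, op))) > 1
    print(op, "excluded from fusion:", multiple_outputs)
# op-a excluded from fusion: True
# op-b excluded from fusion: False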
60 changes: 41 additions & 19 deletions cubed/core/plan.py
@@ -73,7 +73,7 @@ class Plan:
    def __init__(self, dag):
        self.dag = dag

-    # args from pipeline onwards are omitted for creation functions when no computation is needed
+    # args from primitive_op onwards are omitted for creation functions when no computation is needed
    @classmethod
    def _new(
        cls,
@@ -110,15 +110,26 @@ def _new(
                op_display_name=f"{op_name_unique}\n{first_cubed_summary.name}",
                hidden=hidden,
            )
-            # array (when multiple outputs are supported there could be more than one)
-            dag.add_node(
-                name,
-                name=name,
-                type="array",
-                target=target,
-                hidden=hidden,
-            )
-            dag.add_edge(op_name_unique, name)
+            # array
+            if isinstance(name, list):  # multiple outputs
+                for n, t in zip(name, target):
+                    dag.add_node(
+                        n,
+                        name=n,
+                        type="array",
+                        target=t,
+                        hidden=hidden,
+                    )
+                    dag.add_edge(op_name_unique, n)
+            else:  # single output
+                dag.add_node(
+                    name,
+                    name=name,
+                    type="array",
+                    target=target,
+                    hidden=hidden,
+                )
+                dag.add_edge(op_name_unique, name)
        else:
            # op
            dag.add_node(
@@ -132,15 +143,26 @@
                primitive_op=primitive_op,
                pipeline=primitive_op.pipeline,
            )
-            # array (when multiple outputs are supported there could be more than one)
-            dag.add_node(
-                name,
-                name=name,
-                type="array",
-                target=target,
-                hidden=hidden,
-            )
-            dag.add_edge(op_name_unique, name)
+            # array
+            if isinstance(name, list):  # multiple outputs
+                for n, t in zip(name, target):
+                    dag.add_node(
+                        n,
+                        name=n,
+                        type="array",
+                        target=t,
+                        hidden=hidden,
+                    )
+                    dag.add_edge(op_name_unique, n)
+            else:  # single output
+                dag.add_node(
+                    name,
+                    name=name,
+                    type="array",
+                    target=target,
+                    hidden=hidden,
+                )
+                dag.add_edge(op_name_unique, name)
        for x in source_arrays:
            if hasattr(x, "name"):
                dag.add_edge(x.name, op_name_unique)
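In Plan._new, a multiple-outputs op now becomes one op node fanning out to one array node per target. A reduced sketch of the resulting graph shape, with node attributes trimmed and names hypothetical:

import networkx as nx

dag = nx.MultiDiGraph()

op_name_unique = "op-001"
name = ["array-001a", "array-001b"]      # one array node name per output
target = ["/tmp/a.zarr", "/tmp/b.zarr"]  # one target store per output

dag.add_node(op_name_unique, type="op")
if isinstance(name, list):  # multiple outputs
    for n, t in zip(name, target):
        dag.add_node(n, name=n, type="array", target=t)
        dag.add_edge(op_name_unique, n)
else:  # single output
    dag.add_node(name, name=name, type="array", target=target)
    dag.add_edge(op_name_unique, name)

print(sorted(dag.successors(op_name_unique)))
# ['array-001a', 'array-001b']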
(Diffs for the remaining 6 changed files not shown.)
