From 3dbf6d6df7d430086248d3d50f097e6b0b9a84e8 Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 16 Sep 2024 13:06:39 +0100 Subject: [PATCH] Multiple outputs (#419) * Multiple outputs * Get general_blockwise core op working for multiple outputs * Handle fusion of multiple outputs Test for child fusion (where multiple outputs op is fused with its two children) Test for child fusion (where multiple outputs op is fused with its two children) test_fuse_multiple_outputs_diamond sibling fusion test * Mem utilization test for multiple outputs * Allow multiple output functions to just return a tuple * Fix for Zarr v3 --- cubed/array_api/manipulation_functions.py | 12 +- cubed/core/ops.py | 60 +++++---- cubed/core/optimization.py | 21 +++- cubed/core/plan.py | 60 ++++++--- cubed/primitive/blockwise.py | 145 ++++++++++++++-------- cubed/primitive/types.py | 2 +- cubed/tests/primitive/test_blockwise.py | 70 ++++++++++- cubed/tests/test_core.py | 36 +++++- cubed/tests/test_mem_utilization.py | 33 +++-- cubed/tests/test_optimization.py | 127 +++++++++++++++++++ 10 files changed, 448 insertions(+), 118 deletions(-) diff --git a/cubed/array_api/manipulation_functions.py b/cubed/array_api/manipulation_functions.py index 350f924b..d226dd44 100644 --- a/cubed/array_api/manipulation_functions.py +++ b/cubed/array_api/manipulation_functions.py @@ -321,9 +321,9 @@ def key_function(out_key): key_function, x, template, - shape=shape, - dtype=x.dtype, - chunks=outchunks, + shapes=[shape], + dtypes=[x.dtype], + chunkss=[outchunks], ) @@ -402,9 +402,9 @@ def key_function(out_key): _read_stack_chunk, key_function, *arrays, - shape=shape, - dtype=dtype, - chunks=chunks, + shapes=[shape], + dtypes=[dtype], + chunkss=[chunks], axis=axis, fusable=False, ) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index bea16dcb..f6549e43 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -5,7 +5,7 @@ from itertools import product from numbers import Integral, Number from operator import add -from typing import TYPE_CHECKING, Any, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence, Tuple, Union from warnings import warn import ndindex @@ -333,14 +333,14 @@ def general_blockwise( func, key_function, *arrays, - shape, - dtype, - chunks, - target_store=None, - target_path=None, + shapes, + dtypes, + chunkss, + target_stores=None, + target_paths=None, extra_func_kwargs=None, **kwargs, -) -> "Array": +) -> Union["Array", Tuple["Array", ...]]: assert len(arrays) > 0 # replace arrays with zarr arrays @@ -354,10 +354,19 @@ def general_blockwise( num_input_blocks = kwargs.pop("num_input_blocks", None) - name = gensym() spec = check_array_specs(arrays) - if target_store is None: - target_store = new_temp_path(name=name, spec=spec) + + if isinstance(target_stores, list): # multiple outputs + name = [gensym() for _ in range(len(target_stores))] + target_stores = [ + ts if ts is not None else new_temp_path(name=n, spec=spec) + for n, ts in zip(name, target_stores) + ] + else: # single output + name = gensym() + if target_stores is None: + target_stores = [new_temp_path(name=name, spec=spec)] + op = primitive_general_blockwise( func, key_function, @@ -365,13 +374,13 @@ def general_blockwise( allowed_mem=spec.allowed_mem, reserved_mem=spec.reserved_mem, extra_projected_mem=extra_projected_mem, - target_store=target_store, - target_path=target_path, + target_stores=target_stores, + target_paths=target_paths, storage_options=spec.storage_options, compressor=spec.zarr_compressor, - shape=shape, - dtype=dtype, - chunks=chunks, + shapes=shapes, + dtypes=dtypes, + chunkss=chunkss, in_names=in_names, extra_func_kwargs=extra_func_kwargs, num_input_blocks=num_input_blocks, @@ -387,7 +396,10 @@ def general_blockwise( ) from cubed.array_api import Array - return Array(name, op.target_array, spec, plan) + if isinstance(op.target_array, list): # multiple outputs + return tuple(Array(n, ta, spec, plan) for n, ta in zip(name, op.target_array)) + else: # single output + return Array(name, op.target_array, spec, plan) def elemwise(func, *args: "Array", dtype=None) -> "Array": @@ -914,9 +926,9 @@ def key_function(out_key): _concatenate2, key_function, x, - shape=x.shape, - dtype=x.dtype, - chunks=target_chunks, + shapes=[x.shape], + dtypes=[x.dtype], + chunkss=[target_chunks], extra_projected_mem=0, num_input_blocks=(num_input_blocks,), axes=axes, @@ -1229,12 +1241,12 @@ def partial_reduce( axis = tuple(ax for ax in split_every.keys()) combine_sizes = combine_sizes or {} combine_sizes = {k: combine_sizes.get(k, 1) for k in axis} - chunks = [ + chunks = tuple( (combine_sizes[i],) * math.ceil(len(c) / split_every[i]) if i in split_every else c for (i, c) in enumerate(x.chunks) - ] + ) shape = tuple(map(sum, chunks)) def key_function(out_key): @@ -1263,9 +1275,9 @@ def key_function(out_key): _partial_reduce, key_function, x, - shape=shape, - dtype=dtype, - chunks=chunks, + shapes=[shape], + dtypes=[dtype], + chunkss=[chunks], extra_projected_mem=extra_projected_mem, num_input_blocks=(sum(split_every.values()),), reduce_func=func, diff --git a/cubed/core/optimization.py b/cubed/core/optimization.py index 883a257a..86f3b4b5 100644 --- a/cubed/core/optimization.py +++ b/cubed/core/optimization.py @@ -31,9 +31,9 @@ def can_fuse(n): if "primitive_op" not in nodes[op2]: return False - # if node (op2) does not have exactly one input then don't fuse + # if node (op2) does not have exactly one input and output then don't fuse # (it could have no inputs or multiple inputs) - if dag.in_degree(op2) != 1: + if dag.in_degree(op2) != 1 or dag.out_degree(op2) != 1: return False # if input is one of the arrays being computed then don't fuse @@ -91,6 +91,12 @@ def predecessors_unordered(dag, name): yield pre +def successors_unordered(dag, name): + """Return a node's successors in no particular order, with repeats for multiple edges.""" + for pre, _ in dag.out_edges(name): + yield pre + + def predecessor_ops(dag, name): """Return an op node's op predecessors in the same order as the input source arrays for the op. @@ -183,6 +189,17 @@ def can_fuse_predecessors( ) return False + # if any predecessor ops have multiple outputs then don't fuse + # TODO: implement "child fusion" (where a multiple output op fuses its children) + if any( + len(list(successors_unordered(dag, pre))) > 1 + for pre in predecessor_ops(dag, name) + ): + logger.debug( + "can't fuse %s since at least one predecessor has multiple outputs", name + ) + return False + # if node is in never_fuse or always_fuse list then it overrides logic below if never_fuse is not None and name in never_fuse: logger.debug("can't fuse %s since it is in 'never_fuse'", name) diff --git a/cubed/core/plan.py b/cubed/core/plan.py index 7954d8ca..1c4506bf 100644 --- a/cubed/core/plan.py +++ b/cubed/core/plan.py @@ -73,7 +73,7 @@ class Plan: def __init__(self, dag): self.dag = dag - # args from pipeline onwards are omitted for creation functions when no computation is needed + # args from primitive_op onwards are omitted for creation functions when no computation is needed @classmethod def _new( cls, @@ -110,15 +110,26 @@ def _new( op_display_name=f"{op_name_unique}\n{first_cubed_summary.name}", hidden=hidden, ) - # array (when multiple outputs are supported there could be more than one) - dag.add_node( - name, - name=name, - type="array", - target=target, - hidden=hidden, - ) - dag.add_edge(op_name_unique, name) + # array + if isinstance(name, list): # multiple outputs + for n, t in zip(name, target): + dag.add_node( + n, + name=n, + type="array", + target=t, + hidden=hidden, + ) + dag.add_edge(op_name_unique, n) + else: # single output + dag.add_node( + name, + name=name, + type="array", + target=target, + hidden=hidden, + ) + dag.add_edge(op_name_unique, name) else: # op dag.add_node( @@ -132,15 +143,26 @@ def _new( primitive_op=primitive_op, pipeline=primitive_op.pipeline, ) - # array (when multiple outputs are supported there could be more than one) - dag.add_node( - name, - name=name, - type="array", - target=target, - hidden=hidden, - ) - dag.add_edge(op_name_unique, name) + # array + if isinstance(name, list): # multiple outputs + for n, t in zip(name, target): + dag.add_node( + n, + name=n, + type="array", + target=t, + hidden=hidden, + ) + dag.add_edge(op_name_unique, n) + else: # single output + dag.add_node( + name, + name=name, + type="array", + target=target, + hidden=hidden, + ) + dag.add_edge(op_name_unique, name) for x in source_arrays: if hasattr(x, "name"): dag.add_edge(x.name, op_name_unique) diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index b4571149..d36b5fbb 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -1,3 +1,4 @@ +import inspect import itertools import logging import math @@ -61,8 +62,8 @@ class BlockwiseSpec: The number of input blocks read from each input array. reads_map : Dict[str, CubedArrayProxy] Read proxy dictionary keyed by array name. - write : CubedArrayProxy - Write proxy with an ``array`` attribute that supports ``__setitem__``. + writes_list : List[CubedArrayProxy] + Write proxy list where entries have an ``array`` attribute that supports ``__setitem__``. """ key_function: Callable[..., Any] @@ -70,16 +71,13 @@ class BlockwiseSpec: function_nargs: int num_input_blocks: Tuple[int, ...] reads_map: Dict[str, CubedArrayProxy] - write: CubedArrayProxy + writes_list: List[CubedArrayProxy] def apply_blockwise(out_coords: List[int], *, config: BlockwiseSpec) -> None: """Stage function for blockwise.""" # lithops needs params to be lists not tuples, so convert back out_coords_tuple = tuple(out_coords) - out_chunk_key = key_to_slices( - out_coords_tuple, config.write.array, config.write.chunks - ) # get array chunks for input keys, preserving any nested list structure args = [] @@ -90,14 +88,25 @@ def apply_blockwise(out_coords: List[int], *, config: BlockwiseSpec) -> None: arg = map_nested(get_chunk_config, in_key) args.append(arg) - result = config.function(*args) - if isinstance(result, dict): # structured array with named fields - for k, v in result.items(): - v = backend_array_to_numpy_array(v) - config.write.open().set_basic_selection(out_chunk_key, v, fields=k) - else: - result = backend_array_to_numpy_array(result) - config.write.open()[out_chunk_key] = result + results = config.function(*args) + # if blockwise function is a regular function (not a generator) that doesn't return multiple values then make it iterable + if not inspect.isgeneratorfunction(config.function) and not isinstance( + results, tuple + ): + results = (results,) + for i, result in enumerate(results): + out_chunk_key = key_to_slices( + out_coords_tuple, config.writes_list[i].array, config.writes_list[i].chunks + ) + if isinstance(result, dict): # structured array with named fields + for k, v in result.items(): + v = backend_array_to_numpy_array(v) + config.writes_list[i].open().set_basic_selection( + out_chunk_key, v, fields=k + ) + else: + result = backend_array_to_numpy_array(result) + config.writes_list[i].open()[out_chunk_key] = result def key_to_slices( @@ -140,7 +149,7 @@ def blockwise( fusable: bool = True, num_input_blocks: Optional[Tuple[int, ...]] = None, **kwargs, -): +) -> PrimitiveOperation: """Apply a function to multiple blocks from multiple inputs, expressed using concise indexing rules. Unlike ```general_blockwise``, an index notation is used to specify the block mapping, @@ -213,13 +222,13 @@ def blockwise( *arrays, allowed_mem=allowed_mem, reserved_mem=reserved_mem, - target_store=target_store, - target_path=target_path, + target_stores=[target_store], + target_paths=[target_path] if target_path is not None else None, storage_options=storage_options, compressor=compressor, - shape=shape, - dtype=dtype, - chunks=chunks, + shapes=[shape], + dtypes=[dtype], + chunkss=[chunks], in_names=in_names, extra_projected_mem=extra_projected_mem, extra_func_kwargs=extra_func_kwargs, @@ -235,22 +244,22 @@ def general_blockwise( *arrays: Any, allowed_mem: int, reserved_mem: int, - target_store: T_Store, - target_path: Optional[str] = None, + target_stores: List[T_Store], + target_paths: Optional[List[str]] = None, storage_options: Optional[Dict[str, Any]] = None, compressor: Union[dict, str, None] = "default", - shape: T_Shape, - dtype: T_DType, - chunks: T_Chunks, + shapes: List[T_Shape], + dtypes: List[T_DType], + chunkss: List[T_Chunks], in_names: Optional[List[str]] = None, extra_projected_mem: int = 0, extra_func_kwargs: Optional[Dict[str, Any]] = None, fusable: bool = True, num_input_blocks: Optional[Tuple[int, ...]] = None, **kwargs, -): +) -> PrimitiveOperation: """A more general form of ``blockwise`` that uses a function to specify the block - mapping, rather than an index notation. + mapping, rather than an index notation, and which supports multiple outputs. Parameters ---------- @@ -288,35 +297,51 @@ def general_blockwise( array_names = in_names or [f"in_{i}" for i in range(len(arrays))] array_map = {name: array for name, array in zip(array_names, arrays)} - chunks = normalize_chunks(chunks, shape=shape, dtype=dtype) - chunksize = to_chunksize(chunks) - if isinstance(target_store, zarr.Array): - target_array = target_store - else: - target_array = lazy_zarr_array( - target_store, - shape, - dtype, - chunks=chunksize, - path=target_path, - storage_options=storage_options, - compressor=compressor, - ) - func_kwargs = extra_func_kwargs or {} func_with_kwargs = partial(func, **{**kwargs, **func_kwargs}) num_input_blocks = num_input_blocks or (1,) * len(arrays) read_proxies = { name: CubedArrayProxy(array, array.chunks) for name, array in array_map.items() } - write_proxy = CubedArrayProxy(target_array, chunksize) + + write_proxies = [] + output_chunk_memory = 0 + target_array = [] + + for i, target_store in enumerate(target_stores): + chunks_normal = normalize_chunks(chunkss[i], shape=shapes[i], dtype=dtypes[i]) + chunksize = to_chunksize(chunks_normal) + if isinstance(target_store, zarr.Array): + ta = target_store + else: + ta = lazy_zarr_array( + target_store, + shapes[i], + dtype=dtypes[i], + chunks=chunksize, + path=target_paths[i] if target_paths is not None else None, + storage_options=storage_options, + compressor=compressor, + ) + target_array.append(ta) + + write_proxies.append(CubedArrayProxy(ta, chunksize)) + + # only one output chunk is read into memory at a time, so we find the largest + output_chunk_memory = max( + output_chunk_memory, array_memory(dtypes[i], chunksize) * 2 + ) + + if len(target_array) == 1: + target_array = target_array[0] + spec = BlockwiseSpec( key_function, func_with_kwargs, len(arrays), num_input_blocks, read_proxies, - write_proxy, + write_proxies, ) # calculate projected memory @@ -332,7 +357,7 @@ def general_blockwise( # memory for a compressed and an uncompressed output array chunk # - this assumes the blockwise function creates a new array) # - numcodecs uses a working output buffer that's the size of the array being compressed - projected_mem += array_memory(dtype, chunksize) * 2 + projected_mem += output_chunk_memory if projected_mem > allowed_mem: raise ValueError( @@ -340,8 +365,10 @@ def general_blockwise( ) # this must be an iterator of lists, not of tuples, otherwise lithops breaks - output_blocks = map(list, itertools.product(*[range(len(c)) for c in chunks])) - num_tasks = math.prod(len(c) for c in chunks) + output_blocks = map( + list, itertools.product(*[range(len(c)) for c in chunks_normal]) + ) + num_tasks = math.prod(len(c) for c in chunks_normal) pipeline = CubedPipeline( apply_blockwise, @@ -488,7 +515,7 @@ def fused_func(*args): function_nargs = pipeline1.config.function_nargs read_proxies = pipeline1.config.reads_map - write_proxy = pipeline2.config.write + write_proxies = pipeline2.config.writes_list num_input_blocks = tuple( n * pipeline2.config.num_input_blocks[0] for n in pipeline1.config.num_input_blocks @@ -499,7 +526,7 @@ def fused_func(*args): function_nargs, num_input_blocks, read_proxies, - write_proxy, + write_proxies, ) source_array_names = primitive_op1.source_array_names @@ -595,7 +622,7 @@ def apply_pipeline_func(pipeline, n_input_blocks, *args): ret = map(lambda item: pipeline.config.function(*item), zip(*args)) return ret - def fused_func(*args): + def fused_func_single(*args): # args are grouped appropriately so they can be called by each predecessor function func_args = [ apply_pipeline_func(p, pipeline.config.num_input_blocks[i], *a) @@ -603,6 +630,20 @@ def fused_func(*args): ] return pipeline.config.function(*func_args) + # multiple outputs + def fused_func_generator(*args): + # args are grouped appropriately so they can be called by each predecessor function + func_args = [ + apply_pipeline_func(p, pipeline.config.num_input_blocks[i], *a) + for i, (p, a) in enumerate(zip(predecessor_pipelines, args)) + ] + yield from pipeline.config.function(*func_args) + + fused_func = ( + fused_func_generator + if inspect.isgeneratorfunction(pipeline.config.function) + else fused_func_single + ) fused_function_nargs = pipeline.config.function_nargs # ok to get num_input_blocks[0] since it is uniform (see check in can_fuse_multiple_primitive_ops) fused_num_input_blocks = tuple( @@ -618,14 +659,14 @@ def fused_func(*args): for p in predecessor_pipelines: if p is not None: read_proxies.update(p.config.reads_map) - write_proxy = pipeline.config.write + write_proxies = pipeline.config.writes_list spec = BlockwiseSpec( fused_key_func, fused_func, fused_function_nargs, fused_num_input_blocks, read_proxies, - write_proxy, + write_proxies, ) source_array_names = [] diff --git a/cubed/primitive/types.py b/cubed/primitive/types.py index 508a4c8b..ea0a982b 100644 --- a/cubed/primitive/types.py +++ b/cubed/primitive/types.py @@ -19,7 +19,7 @@ class PrimitiveOperation: """The names of the arrays which are inputs to this operation.""" target_array: Any - """The array being computed by this operation.""" + """The array or arrays being computed by this operation.""" projected_mem: int """An upper bound of the memory needed to run a task, in bytes.""" diff --git a/cubed/tests/primitive/test_blockwise.py b/cubed/tests/primitive/test_blockwise.py index c8a01b8e..a5886d6e 100644 --- a/cubed/tests/primitive/test_blockwise.py +++ b/cubed/tests/primitive/test_blockwise.py @@ -1,6 +1,5 @@ import numpy as np import pytest -import zarr from numpy.testing import assert_array_equal from cubed.backend_array_api import namespace as nxp @@ -204,10 +203,10 @@ def key_function(out_key): source, allowed_mem=allowed_mem, reserved_mem=0, - target_store=target_store, - shape=(20,), - dtype=int, - chunks=(6,), + target_stores=[target_store], + shapes=[(20,)], + dtypes=[int], + chunkss=[(6,)], in_names=[in_name], ) @@ -225,6 +224,67 @@ def key_function(out_key): assert_array_equal(res[:], np.arange(20)) +def test_blockwise_multiple_outputs(tmp_path, executor): + source = create_zarr( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + dtype=int, + chunks=(2, 2), + store=tmp_path / "source.zarr", + ) + allowed_mem = 1000 + target_store1 = tmp_path / "target1.zarr" + target_store2 = tmp_path / "target2.zarr" + + in_name = "x" + + def sqrts(x): + yield np.sqrt(x) + yield -np.sqrt(x) + + def block_function(out_key): + out_coords = out_key[1:] + return ((in_name, *out_coords),) + + op = general_blockwise( + sqrts, + block_function, + source, + allowed_mem=allowed_mem, + reserved_mem=0, + target_stores=[target_store1, target_store2], + shapes=[(3, 3), (3, 3)], + dtypes=[float, float], + chunkss=[(2, 2), (2, 2)], + in_names=[in_name], + ) + + assert isinstance(op.target_array, list) + assert len(op.target_array) == 2 + + assert op.target_array[0].shape == (3, 3) + assert op.target_array[0].dtype == float + assert op.target_array[0].chunks == (2, 2) + + assert op.target_array[1].shape == (3, 3) + assert op.target_array[1].dtype == float + assert op.target_array[1].chunks == (2, 2) + + assert op.num_tasks == 4 + + op.target_array[0].create() # create lazy zarr array + op.target_array[1].create() # create lazy zarr array + + execute_pipeline(op.pipeline, executor=executor) + + input = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + res1 = open_backend_array(target_store1, mode="r") + assert_array_equal(res1[:], np.sqrt(input)) + + res2 = open_backend_array(target_store2, mode="r") + assert_array_equal(res2[:], -np.sqrt(input)) + + def test_make_blockwise_key_function_map(): func = lambda x: 0 diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py index 93400bfb..034d5126 100644 --- a/cubed/tests/test_core.py +++ b/cubed/tests/test_core.py @@ -11,8 +11,9 @@ import cubed import cubed.array_api as xp import cubed.random +from cubed.array_api.dtypes import _floating_dtypes from cubed.backend_array_api import namespace as nxp -from cubed.core.ops import merge_chunks, partial_reduce, tree_reduce +from cubed.core.ops import general_blockwise, merge_chunks, partial_reduce, tree_reduce from cubed.core.optimization import fuse_all_optimize_dag, multiple_inputs_optimize_dag from cubed.storage.backend import open_backend_array from cubed.tests.utils import ( @@ -676,3 +677,36 @@ def test_quad_means_zarr(tmp_path, t_length=50): m.visualize(filename=tmp_path / "quad_means", optimize_function=opt_fn) cubed.to_zarr(m, store=tmp_path / "result", optimize_function=opt_fn) + + +def sqrts(x): + if x.dtype not in _floating_dtypes: + raise TypeError("Only floating-point dtypes are allowed in sqrts") + + def _sqrts(x): + yield nxp.sqrt(x) + yield -nxp.sqrt(x) + + def block_function(out_key): + return ((x.name,) + out_key[1:],) + + return general_blockwise( + _sqrts, + block_function, + x, + shapes=[x.shape, x.shape], + dtypes=[x.dtype, x.dtype], + chunkss=[x.chunks, x.chunks], + target_stores=[None, None], # filled in by general_blockwise + ) + + +def test_multiple_outputs(): + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), dtype=float) + b, c = sqrts(a) + + cubed.compute(b, c) + + input = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert_array_equal(b, np.sqrt(input)) + assert_array_equal(c, -np.sqrt(input)) diff --git a/cubed/tests/test_mem_utilization.py b/cubed/tests/test_mem_utilization.py index edccca50..6bb1e650 100644 --- a/cubed/tests/test_mem_utilization.py +++ b/cubed/tests/test_mem_utilization.py @@ -19,6 +19,7 @@ from cubed.diagnostics.mem_warn import MemoryWarningCallback from cubed.diagnostics.memray import MemrayCallback from cubed.runtime.create import create_executor +from cubed.tests.test_core import sqrts from cubed.tests.utils import LITHOPS_LOCAL_CONFIG pd.set_option("display.max_columns", None) @@ -320,6 +321,19 @@ def test_sum_partial_reduce(tmp_path, spec, executor): run_operation(tmp_path, executor, "sum_partial_reduce", b) +# Multiple outputs + + +@pytest.mark.slow +def test_sqrts(tmp_path, spec, executor): + a = cubed.random.random( + (10000, 10000), chunks=(5000, 5000), spec=spec + ) # 200MB chunks + b, c = sqrts(a) + # don't optimize graph so we use as much memory as possible (reading from Zarr) + run_operation(tmp_path, executor, "sqrts", b, c, optimize_graph=False) + + # Internal functions @@ -327,20 +341,23 @@ def run_operation( tmp_path, executor, name, - result_array, - *, + *results, optimize_graph=True, optimize_function=None, ): - # result_array.visualize(f"cubed-{name}-unoptimized", optimize_graph=False, show_hidden=True) - # result_array.visualize(f"cubed-{name}", optimize_function=optimize_function) + # cubed.visualize( + # *results, filename=f"cubed-{name}-unoptimized", optimize_graph=False, show_hidden=True + # ) + # cubed.visualize( + # *results, filename=f"cubed-{name}", optimize_function=optimize_function + # ) hist = HistoryCallback() mem_warn = MemoryWarningCallback() memray = MemrayCallback() - # use store=None to write to temporary zarr - cubed.to_zarr( - result_array, - store=None, + # use None for each store to write to temporary zarr + cubed.store( + results, + (None,) * len(results), executor=executor, callbacks=[hist, mem_warn, memray], optimize_graph=optimize_graph, diff --git a/cubed/tests/test_optimization.py b/cubed/tests/test_optimization.py index 7b685cf8..d3feae77 100644 --- a/cubed/tests/test_optimization.py +++ b/cubed/tests/test_optimization.py @@ -19,6 +19,7 @@ simple_optimize_dag, ) from cubed.core.plan import arrays_to_plan +from cubed.tests.test_core import sqrts from cubed.tests.utils import TaskCounter @@ -981,6 +982,132 @@ def test_fuse_partial_reduce_binary(spec): assert_array_equal(result, 6 * np.ones((1, 2))) +# unary op followed by multiple outputs +# +# a -> a +# | / \ +# b c d +# / \ +# c d +# +def test_fuse_unary_op_and_multiple_outputs(spec): + a = xp.ones((2, 2), chunks=(2, 2), spec=spec) + b = xp.positive(a) + c, d = sqrts(b) + + opt_fn = fuse_multiple_levels() + + cubed.visualize(c, d, optimize_function=opt_fn) + + # check structure of optimized dag + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (a,), (c, d)) + plan = arrays_to_plan(c, d) + optimized_dag = plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + + c_result, d_result = cubed.compute(c, d, optimize_function=opt_fn) + assert_array_equal(c_result, np.ones((2, 2))) + assert_array_equal(d_result, -np.ones((2, 2))) + + +# multiple outputs followed by unary ops +# note: this is not yet implemented +# +# a -> a +# / \ / \ +# b c d e +# | | +# d e +# +def test_fuse_multiple_outputs_and_unary_op(spec): + a = xp.ones((2, 2), chunks=(2, 2), spec=spec) + b, c = sqrts(a) + d = xp.negative(b) + e = xp.negative(c) + + opt_fn = fuse_multiple_levels() + + cubed.visualize(d, e, optimize_function=opt_fn) + + # # check structure of optimized dag + # expected_fused_dag = create_dag() + # add_placeholder_op(expected_fused_dag, (), (a,)) + # add_placeholder_op(expected_fused_dag, (a,), (d, e)) + # plan = arrays_to_plan(d, e) + # optimized_dag = plan.optimize(optimize_function=opt_fn).dag + # assert structurally_equivalent(optimized_dag, expected_fused_dag) + + d_result, e_result = cubed.compute(d, e, optimize_function=opt_fn) + assert_array_equal(d_result, -np.ones((2, 2))) + assert_array_equal(e_result, np.ones((2, 2))) + + +# multiple outputs diamond +# note: this is not yet implemented +# +# a -> a +# / \ | +# b c d +# \ / +# d +# +def test_fuse_multiple_outputs_diamond(spec): + a = xp.ones((2, 2), chunks=(2, 2), spec=spec) + b, c = sqrts(a) + d = xp.add(b, c) + + opt_fn = fuse_multiple_levels() + + d.visualize(optimize_function=opt_fn) + + # # check structure of optimized dag + # expected_fused_dag = create_dag() + # add_placeholder_op(expected_fused_dag, (), (a,)) + # add_placeholder_op(expected_fused_dag, (a,), (d,)) + # optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag + # assert structurally_equivalent(optimized_dag, expected_fused_dag) + + result = d.compute(optimize_function=opt_fn) + assert_array_equal(result, np.zeros((2, 2))) + + +# sibling fusion +# note: this is not yet implemented +# +# (notation is more explicit here - 'o' is an op) +# +# o -> o +# | | +# a a +# / \ | +# o o o +# | | / \ +# b c b c +# +def test_fuse_siblings(spec): + a = xp.ones((2, 2), chunks=(2, 2), spec=spec) + b = xp.positive(a) + c = xp.negative(a) + + opt_fn = fuse_multiple_levels() + + cubed.visualize(b, c, optimize_function=opt_fn) + + # # check structure of optimized dag + # expected_fused_dag = create_dag() + # add_placeholder_op(expected_fused_dag, (), (a,)) + # add_placeholder_op(expected_fused_dag, (a,), (b, c)) + # plan = arrays_to_plan(b, c) + # optimized_dag = plan.optimize(optimize_function=opt_fn).dag + # assert structurally_equivalent(optimized_dag, expected_fused_dag) + + b_result, c_result = cubed.compute(b, c, optimize_function=opt_fn) + assert_array_equal(b_result, np.ones((2, 2))) + assert_array_equal(c_result, -np.ones((2, 2))) + + def test_fuse_only_optimize_dag(spec): a = xp.ones((2, 2), chunks=(2, 2), spec=spec) b = xp.negative(a)