Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@
"https://documen.tician.de/loopy/": None,
"https://documen.tician.de/sumpy/": None,
"https://documen.tician.de/islpy/": None,
"https://pyrsistent.readthedocs.io/en/latest/": None,
}

import sys
# Flag attached to the ``sys`` module for the duration of the docs build.
# NOTE(review): presumably inspected by pytato at import time -- confirm.
sys.PYTATO_BUILDING_SPHINX_DOCS = True

# (role, target-regex) pairs for which nitpicky mode should stay silent.
# The dots are escaped so that they only match the literal module separator
# rather than an arbitrary character.
nitpick_ignore_regex = [
    ["py:class", r"numpy\.(u?)int[\d]+"],
    ["py:class", r"pyrsistent\.typing\.(.+)"],
]
37 changes: 37 additions & 0 deletions examples/demo_distributed_node_duplication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
An example to demonstrate the behavior of
:func:`pytato.find_distrbuted_partition`. One of the key characteristic of the
partitioning routine is to recompute expressions that appear in the multiple
partitions but are not materialized.
"""
import pytato as pt
import numpy as np

size = 2
rank = 0

x1 = pt.make_placeholder("x1", shape=(10, 4), dtype=np.float64)
x2 = pt.make_placeholder("x2", shape=(10, 4), dtype=np.float64)
x3 = pt.make_placeholder("x3", shape=(10, 4), dtype=np.float64)
x4 = pt.make_placeholder("x4", shape=(10, 4), dtype=np.float64)


tmp1 = (x1 + x2).tagged(pt.tags.ImplStored())
tmp2 = tmp1 + x3
# "marking" *tmp2* so that its duplication can be clearly visualized.
tmp2 = tmp2.tagged(pt.tags.Named("tmp2"))
tmp3 = (2 * x4).tagged(pt.tags.ImplStored())
tmp4 = tmp2 + tmp3

recv = pt.staple_distributed_send(tmp4, dest_rank=(rank-1) % size, comm_tag=10,
stapled_to=pt.make_distributed_recv(
src_rank=(rank+1) % size, comm_tag=10, shape=(10, 4), dtype=int))

out = tmp2 + recv
result = pt.make_dict_of_named_arrays({"out": out})

partitions = pt.find_distributed_partition(result)

# Visualize *partitions* to see that each of the two partitions contains a node
# named 'tmp2'.
pt.show_dot_graph(partitions)
3 changes: 2 additions & 1 deletion pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@
staple_distributed_send,
find_distributed_partition,
number_distributed_tags,
execute_distributed_partition)
execute_distributed_partition,
)

from pytato.partition import generate_code_for_partition

Expand Down
109 changes: 107 additions & 2 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,18 @@
THE SOFTWARE.
"""

from typing import Mapping, Dict, Union, Set, Tuple, Any
from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet,
TYPE_CHECKING)
from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum,
DictOfNamedArrays, NamedArray,
IndexBase, IndexRemappingBase, InputArgumentBase)
IndexBase, IndexRemappingBase, InputArgumentBase,
ShapeType)
from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper
from pytato.loopy import LoopyCall

if TYPE_CHECKING:
from pytato.distributed import DistributedRecv, DistributedSendRefHolder

__doc__ = """
.. currentmodule:: pytato.analysis

Expand All @@ -40,13 +45,22 @@
.. autofunction:: is_einsum_similar_to_subscript

.. autofunction:: get_num_nodes

.. autoclass:: DirectPredecessorsGetter
"""


class NUserCollector(Mapper):
"""
A :class:`pytato.transform.CachedWalkMapper` that records the number of
times an array expression is a direct dependency of other nodes.

.. note::

- We do not consider the :class:`pytato.DistributedSendRefHolder`
a user of :attr:`pytato.DistributedSendRefHolder.send`. This is
because in a data flow sense, the send-ref holder does not use the
send's data.
"""
def __init__(self) -> None:
from collections import defaultdict
Expand Down Expand Up @@ -141,6 +155,20 @@ def _map_input_base(self, expr: InputArgumentBase) -> None:
map_data_wrapper = _map_input_base
map_size_param = _map_input_base

def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder
                                    ) -> None:
    """
    Record *expr* as a user of its passthrough data, then visit both the
    passthrough data and the data being sent.
    """
    passthrough = expr.passthrough_data

    # 'expr.send.data' is intentionally not counted as being used by *expr*:
    # no data flows from it into *expr*. It is still traversed below.
    self.nusers[passthrough] += 1
    self.rec(passthrough)

    sent_data = expr.send.data
    self.rec(sent_data)

def map_distributed_recv(self, expr: DistributedRecv) -> None:
    """
    Record *expr* as a user of every array-valued entry of its shape and
    visit each such entry.
    """
    for axis_len in expr.shape:
        # Entries that are not arrays contribute no dependency.
        if not isinstance(axis_len, Array):
            continue
        self.nusers[axis_len] += 1
        self.rec(axis_len)


def get_nusers(outputs: Union[Array, DictOfNamedArrays]) -> Mapping[Array, int]:
"""
Expand Down Expand Up @@ -246,6 +274,83 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:
return True


# {{{ DirectPredecessorsGetter

class DirectPredecessorsGetter(Mapper):
    """
    Mapper that returns, as a :class:`frozenset`, the
    `direct predecessors
    <https://en.wikipedia.org/wiki/Glossary_of_graph_theory#direct_predecessor>`__
    of the node it is applied to.

    .. note::

        Only predecessors in a data-flow sense are reported.
    """
    def _get_preds_from_shape(self, shape: ShapeType) -> FrozenSet[Array]:
        # Only array-valued shape entries are predecessors.
        return frozenset(dim for dim in shape if isinstance(dim, Array))

    def map_index_lambda(self, expr: IndexLambda) -> FrozenSet[Array]:
        bound_arrays = frozenset(expr.bindings.values())
        return bound_arrays | self._get_preds_from_shape(expr.shape)

    def map_stack(self, expr: Stack) -> FrozenSet[Array]:
        stacked = frozenset(expr.arrays)
        return stacked | self._get_preds_from_shape(expr.shape)

    def map_concatenate(self, expr: Concatenate) -> FrozenSet[Array]:
        joined = frozenset(expr.arrays)
        return joined | self._get_preds_from_shape(expr.shape)

    def map_einsum(self, expr: Einsum) -> FrozenSet[Array]:
        operands = frozenset(expr.args)
        return operands | self._get_preds_from_shape(expr.shape)

    def map_loopy_call_result(self, expr: NamedArray) -> FrozenSet[Array]:
        from pytato.loopy import LoopyCallResult, LoopyCall
        assert isinstance(expr, LoopyCallResult)
        call = expr._container
        assert isinstance(call, LoopyCall)
        # Bindings that are not arrays are not predecessors.
        array_bindings = frozenset(
            bound for bound in call.bindings.values()
            if isinstance(bound, Array))
        return array_bindings | self._get_preds_from_shape(expr.shape)

    def _map_index_base(self, expr: IndexBase) -> FrozenSet[Array]:
        # The indexed array plus every index that is itself an array.
        array_indices = [idx for idx in expr.indices if isinstance(idx, Array)]
        indexing_preds = frozenset([expr.array, *array_indices])
        return indexing_preds | self._get_preds_from_shape(expr.shape)

    map_basic_index = _map_index_base
    map_contiguous_advanced_index = _map_index_base
    map_non_contiguous_advanced_index = _map_index_base

    def _map_index_remapping_base(self, expr: IndexRemappingBase
                                  ) -> FrozenSet[Array]:
        # Only the remapped array is reported here.
        return frozenset((expr.array,))

    map_roll = _map_index_remapping_base
    map_axis_permutation = _map_index_remapping_base
    map_reshape = _map_index_remapping_base

    def _map_input_base(self, expr: InputArgumentBase) -> FrozenSet[Array]:
        # Inputs depend on nothing but the arrays appearing in their shape.
        shape_preds = self._get_preds_from_shape(expr.shape)
        return shape_preds

    map_placeholder = _map_input_base
    map_data_wrapper = _map_input_base
    map_size_param = _map_input_base

    def map_distributed_recv(self, expr: DistributedRecv) -> FrozenSet[Array]:
        shape_preds = self._get_preds_from_shape(expr.shape)
        return shape_preds

    def map_distributed_send_ref_holder(self,
                                        expr: DistributedSendRefHolder
                                        ) -> FrozenSet[Array]:
        # Only the passthrough data is reported; 'expr.send.data' is not.
        return frozenset((expr.passthrough_data,))

# }}}


# {{{ NodeCountMapper

class NodeCountMapper(CachedWalkMapper):
Expand Down
Loading