examples/distributed.py (2 changes: 2 additions & 0 deletions)
@@ -30,6 +30,8 @@ def main():

    # Find the partition
    outputs = pt.DictOfNamedArrays({"out": y})

    pt.verify_distributed_dag_pre_partition(comm, outputs)
    distributed_parts = find_distributed_partition(outputs)
    distributed_parts, _ = number_distributed_tags(
        comm, distributed_parts, base_tag=42)
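For reference, a minimal sketch of the call order the example now uses: run the collective pre-partition check first, then partition and number the tags. This is illustrative only: comm and y are assumed from the surrounding example, and the exception base class lives in pytato.distributed.verify.

    import pytato as pt
    from pytato.distributed.verify import DistributedPartitionVerificationError

    outputs = pt.DictOfNamedArrays({"out": y})

    try:
        # Collective: every rank must enter this call.
        pt.verify_distributed_dag_pre_partition(comm, outputs)
    except DistributedPartitionVerificationError as exc:
        raise RuntimeError(f"inconsistent distributed DAG: {exc}") from exc

    distributed_parts = pt.find_distributed_partition(outputs)
    distributed_parts, _ = pt.number_distributed_tags(
            comm, distributed_parts, base_tag=42)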
pytato/__init__.py (4 changes: 3 additions & 1 deletion)
@@ -100,7 +100,8 @@ def set_debug_enabled(flag: bool) -> None:
from pytato.distributed.partition import (
    find_distributed_partition, DistributedGraphPart, DistributedGraphPartition)
from pytato.distributed.tags import number_distributed_tags
from pytato.distributed.verify import verify_distributed_partition
from pytato.distributed.verify import (verify_distributed_partition,
        verify_distributed_dag_pre_partition)
from pytato.distributed.execute import execute_distributed_partition

from pytato.transform.lower_to_index_lambda import to_index_lambda
@@ -161,6 +162,7 @@ def set_debug_enabled(flag: bool) -> None:
"number_distributed_tags",
"execute_distributed_partition",
"verify_distributed_partition",
"verify_distributed_dag_pre_partition",

"generate_code_for_partition",

pytato/distributed/verify.py (186 changes: 175 additions & 11 deletions)
@@ -30,14 +30,17 @@
"""


from typing import Any, FrozenSet, Dict, Set, Optional, Sequence, TYPE_CHECKING
from typing import FrozenSet, Dict, Set, Optional, Sequence, TYPE_CHECKING, Union
from immutables import Map

import numpy as np

from pytato.distributed.nodes import CommTagType, DistributedRecv
from pytato.distributed.nodes import (CommTagType, DistributedRecv,
        DistributedSendRefHolder, DistributedSend)
from pytato.partition import PartId
from pytato.distributed.partition import DistributedGraphPartition
from pytato.array import ShapeType
from pytato.transform import UsersCollector, ArrayOrNames
from pytato import DictOfNamedArrays

from pytools import UniqueNameGenerator

import attrs

@@ -58,9 +61,6 @@ class _SummarizedDistributedSend:
    dest_rank: int
    comm_tag: CommTagType

    shape: ShapeType
    dtype: np.dtype[Any]


@attrs.define(frozen=True)
class _DistributedPartId:
@@ -74,6 +74,12 @@ class _DistributedName:
    name: str


@attrs.define(frozen=True)
class _DistributedNode:
    rank: int
    node: Union[DistributedSend, ArrayOrNames, _SummarizedDistributedSend]


@attrs.define(frozen=True)
class _SummarizedDistributedGraphPart:
    pid: _DistributedPartId
@@ -89,6 +95,20 @@ def rank(self) -> int:
        return self.pid.rank


@attrs.define(frozen=True)
class _SummarizedDistributedGraph:
    rank: int
    node_to_users: Dict[_DistributedNode, Set[_DistributedNode]]
    input_name_to_recv_node: Dict[_DistributedName, DistributedRecv]
    output_name_to_send_node: Dict[_DistributedName, _SummarizedDistributedSend]

    def __hash__(self) -> int:
        return (hash(self.rank)
                ^ hash(Map(self.node_to_users))
                ^ hash(Map(self.input_name_to_recv_node))
                ^ hash(Map(self.output_name_to_send_node)))


@attrs.define(frozen=True)
class _CommIdentifier:
    src_rank: int
@@ -122,6 +142,152 @@ class MissingRecvError(DistributedPartitionVerificationError):
# }}}


# {{{ verify_distributed_dag_pre_partition

class _DistributedDAGGatherer(UsersCollector):
    def __init__(self, dist_name_generator: UniqueNameGenerator, my_rank: int) \
            -> None:
        super().__init__()

        self.name_generator = dist_name_generator
        self.my_rank = my_rank

        self.input_name_to_recv_node: Dict[str, DistributedRecv] = {}
        self.output_name_to_send_node: Dict[str, _SummarizedDistributedSend] = {}

    def map_distributed_recv(self, expr: DistributedRecv) -> None:
        super().map_distributed_recv(expr)
        new_name = self.name_generator()
        self.input_name_to_recv_node[new_name] = expr

    def map_distributed_send_ref_holder(
            self, expr: DistributedSendRefHolder) -> None:
        super().map_distributed_send_ref_holder(expr)
        s = _SummarizedDistributedSend(
            src_rank=self.my_rank,
            dest_rank=expr.send.dest_rank,
            comm_tag=expr.send.comm_tag)

        new_name = self.name_generator()
        self.output_name_to_send_node[new_name] = s


def verify_distributed_dag_pre_partition(mpi_communicator: mpi4py.MPI.Comm,
        outputs: DictOfNamedArrays) -> None:
    """Verify, on the unpartitioned DAG *outputs*, that sends and receives
    are not duplicated, that every receive has a matching send, and that
    cross-rank dependencies are acyclic (i.e. deadlock-free).

    .. warning::

        This is an MPI-collective operation.
    """
    my_rank = mpi_communicator.rank
    root_rank = 0

    ung = UniqueNameGenerator(forced_prefix="_pt_verify_dist_")

    dg = _DistributedDAGGatherer(ung, my_rank)
    dg(outputs)

    def dist_send_to_summarized_dist_send(node:
            Union[ArrayOrNames, _SummarizedDistributedSend, DistributedSend]) \
            -> Union[ArrayOrNames, _SummarizedDistributedSend]:
        if isinstance(node, DistributedSend):
            return _SummarizedDistributedSend(
                src_rank=my_rank,
                dest_rank=node.dest_rank,
                comm_tag=node.comm_tag)
        elif isinstance(node, DistributedSendRefHolder):
            return _SummarizedDistributedSend(
                src_rank=my_rank,
                dest_rank=node.send.dest_rank,
                comm_tag=node.send.comm_tag)
        else:
            return node

    summarized_dag = _SummarizedDistributedGraph(
        rank=my_rank,
        node_to_users={
            _DistributedNode(my_rank, k):
            {_DistributedNode(my_rank, dist_send_to_summarized_dist_send(n))
                for n in v}
            for k, v in dg.node_to_users.items()},
        input_name_to_recv_node={
            _DistributedName(my_rank, name): recv
            for name, recv in dg.input_name_to_recv_node.items()},
        output_name_to_send_node={
            _DistributedName(my_rank, name):
            _SummarizedDistributedSend(
                src_rank=my_rank,
                dest_rank=send.dest_rank,
                comm_tag=send.comm_tag)
            for name, send in dg.output_name_to_send_node.items()})

    all_outputs = mpi_communicator.gather(summarized_dag, root=root_rank)

    if my_rank == root_rank:
        assert all_outputs

        all_summarized_outputs = {
            rank: rank_outputs
            for rank, rank_outputs in enumerate(all_outputs)}

        all_recvs: Set[_CommIdentifier] = set()

        send_recv_deps: Dict[_DistributedNode, Set[_DistributedNode]] = {}

        def add_send_recv_dep(recv: _DistributedNode,
                send: _DistributedNode) -> None:
            send_recv_deps.setdefault(recv, set()).add(send)

        # {{{ gather information on senders

        comm_id_to_sending_node = {}

        for sumdag in all_summarized_outputs.values():
            for sumsend in sumdag.output_name_to_send_node.values():
                comm_id = _CommIdentifier(
                    src_rank=sumsend.src_rank,
                    dest_rank=sumsend.dest_rank,
                    comm_tag=sumsend.comm_tag)

                if comm_id in comm_id_to_sending_node:
                    raise DuplicateSendError(
                        f"duplicate send for comm id: '{comm_id}'")
                comm_id_to_sending_node[comm_id] = sumsend

        # }}}

        for sumdag in all_summarized_outputs.values():
            send_recv_deps.update(sumdag.node_to_users)

            for dname, dist_recv in sumdag.input_name_to_recv_node.items():
                comm_id = _CommIdentifier(
                    src_rank=dist_recv.src_rank,
                    dest_rank=dname.rank,
                    comm_tag=dist_recv.comm_tag)

                if comm_id in all_recvs:
                    raise DuplicateRecvError(f"duplicate recv: '{comm_id}'")

                all_recvs.add(comm_id)

                # Add cross-rank edges between matching sends and receives
                try:
                    sending_node = comm_id_to_sending_node[comm_id]
                except KeyError:
                    raise MissingSendError(
                        f"no matching send for recv on '{comm_id}'")

                add_send_recv_dep(_DistributedNode(comm_id.dest_rank, dist_recv),
                        _DistributedNode(comm_id.src_rank, sending_node))

        from pytools.graph import compute_topological_order

        # Raises CycleError if the cross-rank send/recv dependency graph
        # contains a cycle, i.e. if the communication pattern would deadlock.
        compute_topological_order(send_recv_deps)

    logger.info("verify_distributed_dag_pre_partition completed successfully.")

# }}}
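The compute_topological_order call above is what reduces deadlock detection to a cycle check on the cross-rank send/recv dependency graph. Below is a minimal, self-contained sketch (hypothetical string nodes, not pytato data structures) of how a cyclic wait chain surfaces as a pytools CycleError:

    from pytools.graph import CycleError, compute_topological_order

    # Hypothetical wait chain between two ranks: each recv waits on the
    # matching send from the other rank, and each send waits on its own
    # rank's recv. The result is a cycle, i.e. a guaranteed deadlock.
    send_recv_deps = {
        "recv@rank0": {"send@rank1"},
        "send@rank1": {"recv@rank1"},
        "recv@rank1": {"send@rank0"},
        "send@rank0": {"recv@rank0"},
    }

    try:
        compute_topological_order(send_recv_deps)
    except CycleError:
        print("cycle found: this communication pattern would deadlock")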


# {{{ verify_distributed_partition

def verify_distributed_partition(mpi_communicator: mpi4py.MPI.Comm,
@@ -169,9 +335,7 @@ def verify_distributed_partition(mpi_communicator: mpi4py.MPI.Comm,
                _SummarizedDistributedSend(
                    src_rank=my_rank,
                    dest_rank=send.dest_rank,
                    comm_tag=send.comm_tag,
                    shape=send.data.shape,
                    dtype=send.data.dtype)
                    comm_tag=send.comm_tag)
                for name, send in part.output_name_to_send_node.items()})

    # Gather the _SummarizedDistributedGraphPart's to rank 0
test/test_distributed.py (37 changes: 36 additions & 1 deletion)
@@ -308,7 +308,7 @@ def test_deterministic_partitioning():
# }}}


# {{{ test verify_distributed_partition
# {{{ test verify_distributed_partition and verify_distributed_dag_pre_partition

def test_verify_distributed_partition():
    run_test_with_mpi(2, _do_verify_distributed_partition)
@@ -331,6 +331,13 @@ def _do_verify_distributed_partition(ctx_factory):
        src_rank=(rank+1) % size, comm_tag=42, shape=(4, 4), dtype=int)

    outputs = pt.make_dict_of_named_arrays({"out": y})

    if rank == 0:
        with pytest.raises(MissingSendError):
            pt.verify_distributed_dag_pre_partition(comm, outputs)
    else:
        pt.verify_distributed_dag_pre_partition(comm, outputs)

    distributed_parts = pt.find_distributed_partition(outputs)

    if rank == 0:
@@ -350,6 +357,13 @@ def _do_verify_distributed_partition(ctx_factory):
outputs = pt.make_dict_of_named_arrays({"out": send})
distributed_parts = pt.find_distributed_partition(outputs)

if rank == 0:
# FIXME: this should raise
with pytest.raises(MissingRecvError):
pt.verify_distributed_dag_pre_partition(comm, outputs)
else:
pt.verify_distributed_dag_pre_partition(comm, outputs)

if rank == 0:
with pytest.raises(MissingRecvError):
pt.verify_distributed_partition(comm, distributed_parts)
@@ -368,6 +382,13 @@ def _do_verify_distributed_partition(ctx_factory):
        src_rank=(rank+1) % size, comm_tag=42, shape=(4, 4), dtype=int))

    outputs = pt.make_dict_of_named_arrays({"out": x+send})

    if rank == 0:
        with pytest.raises(MissingSendError):
            pt.verify_distributed_dag_pre_partition(comm, outputs)
    else:
        pt.verify_distributed_dag_pre_partition(comm, outputs)

    distributed_parts = pt.find_distributed_partition(outputs)

    if rank == 0:
@@ -388,6 +409,13 @@ def _do_verify_distributed_partition(ctx_factory):
        dest_rank=(rank-1) % size, comm_tag=42, stapled_to=x)

    outputs = pt.make_dict_of_named_arrays({"out": send+send2})

    if rank == 0:
        with pytest.raises(DuplicateSendError):
            pt.verify_distributed_dag_pre_partition(comm, outputs)
    else:
        pt.verify_distributed_dag_pre_partition(comm, outputs)

    distributed_parts = pt.find_distributed_partition(outputs)

    if rank == 0:
@@ -413,6 +441,13 @@ def _do_verify_distributed_partition(ctx_factory):
        stapled_to=recv)

    outputs = pt.make_dict_of_named_arrays({"out": send+send2})

    if rank == 0:
        with pytest.raises(MissingSendError):
            pt.verify_distributed_dag_pre_partition(comm, outputs)
    else:
        pt.verify_distributed_dag_pre_partition(comm, outputs)
Review comment (Owner), on lines +445 to +449: Find a way to make these less wordy.
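One hypothetical way to do that (a sketch, not part of this change): factor the repeated rank-conditional assertion into a small helper built on pytest.raises and contextlib.nullcontext.

    from contextlib import contextmanager, nullcontext

    @contextmanager
    def raises_on_rank(rank, raising_rank, exc_type):
        # Expect exc_type on raising_rank; expect plain success elsewhere.
        ctx = pytest.raises(exc_type) if rank == raising_rank else nullcontext()
        with ctx:
            yield

    # Each five-line check above would then shrink to two lines, e.g.:
    #     with raises_on_rank(rank, 0, MissingSendError):
    #         pt.verify_distributed_dag_pre_partition(comm, outputs)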
    distributed_parts = pt.find_distributed_partition(outputs)

    if rank == 0: