Closed

Changes from all commits (29 commits)
f7923b8
run_test_with_mpi: Allow passing extra_env_vars
inducer Jan 13, 2023
38e25f7
New distributed-memory DAG partitioner
inducer Nov 22, 2022
5f2ee88
TMP: Towards getting the new partitioner ready
inducer Jan 10, 2023
255f98c
Add a test for Kaushik's distributed MWE
kaushikcfd Jan 12, 2023
d6798a9
Finish removing pytato.partition
inducer Jan 12, 2023
065c320
Do not include received arrays in part outputs
inducer Jan 12, 2023
f8dacd3
Rework test_deterministic_partitioning
inducer Jan 13, 2023
576f56a
remove assumption that arrays are sent only once
majosm Jan 13, 2023
fb4e501
remove assumption that output arrays are unique
majosm Jan 13, 2023
88e7c5c
disable partition disjointness check
majosm Jan 13, 2023
2bd86bc
add missing default init for extra_env_vars
majosm Jan 18, 2023
f1f184d
remove pu.db call
majosm Jan 18, 2023
232a2f6
add test for dag with duplicated output array
majosm Jan 18, 2023
df4670e
add test for dag with a receive as an output
majosm Jan 18, 2023
da52e71
add test for dag with multiple send nodes per sent array
majosm Jan 18, 2023
e99b4c0
fix mypy errors
majosm Jan 18, 2023
3b95b54
fix docs
majosm Jan 18, 2023
17ede0b
make distributed example work serially (for CI)
majosm Jan 18, 2023
ca9205b
make get_dot_graph_from_partition work with duplicated computation
majosm Jan 30, 2023
9005647
Merge branch 'partition-dot-graph-duplicated-computation' into dist-m…
majosm Feb 1, 2023
61c8671
fix handling of materialized arrays that are part outputs
majosm Jan 31, 2023
3d80957
make sure placeholders are unique
majosm Jan 31, 2023
e805b58
fix typo
majosm Jan 31, 2023
1a87ff7
put x on device
majosm Jan 31, 2023
03304c5
fix and simplify handling of part output generation
majosm Jan 31, 2023
6f68c06
add test for materialized arrays promoted to part outputs
majosm Jan 31, 2023
42982b9
make partitioning deterministic
majosm Feb 6, 2023
5b9be2c
add test for periodic communication
majosm Feb 7, 2023
7c3f47d
enhance _OrderedSet type annotations
majosm Feb 8, 2023
5 changes: 0 additions & 5 deletions doc/dag.rst
@@ -25,11 +25,6 @@ Stringifying Expression Graphs

.. _partitioning:

Partitioning Array Expression Graphs
====================================

.. automodule:: pytato.partition

.. _distributed:

Support for Distributed-Memory/Message Passing
5 changes: 0 additions & 5 deletions doc/design.rst
@@ -154,11 +154,6 @@ Reserved Identifiers
names of :class:`~pytato.array.DataWrapper` arguments that are
not supplied by the user.

- ``_pt_part_ph``: Used to automatically generate identifiers for
names of :class:`~pytato.array.Placeholder` that represent data
transport across parts of a partitioned DAG
(cf. :func:`~pytato.partition.find_partition`).

- ``_pt_dist``: Used to automatically generate identifiers for
placeholders at distributed communication boundaries.

55 changes: 41 additions & 14 deletions examples/distributed.py
@@ -21,16 +21,31 @@ def main():
x_in = rng.integers(100, size=(4, 4))
x = pt.make_data_wrapper(x_in)

mytag = (main, "x")
halo = staple_distributed_send(x, dest_rank=(rank-1) % size, comm_tag=mytag,
stapled_to=make_distributed_recv(
src_rank=(rank+1) % size, comm_tag=mytag, shape=(4, 4), dtype=int))

y = x+halo
if size > 1:
mytag_x = (main, "x")
x_plus = staple_distributed_send(x, dest_rank=(rank-1) % size,
comm_tag=mytag_x, stapled_to=make_distributed_recv(
src_rank=(rank+1) % size, comm_tag=mytag_x, shape=(4, 4),
dtype=int))

y = x+x_plus

mytag_y = (main, "y")
y_plus = staple_distributed_send(y, dest_rank=(rank-1) % size,
comm_tag=mytag_y, stapled_to=make_distributed_recv(
src_rank=(rank+1) % size, comm_tag=mytag_y, shape=(4, 4),
dtype=int))

z = y+y_plus
else:
# Self-sends aren't currently supported
y = x+x
z = y+y

# Find the partition
outputs = pt.make_dict_of_named_arrays({"out": y})
distributed_parts = find_distributed_partition(outputs)
outputs = pt.make_dict_of_named_arrays({"out": z})
distributed_parts = find_distributed_partition(comm, outputs)

distributed_parts, _ = number_distributed_tags(
comm, distributed_parts, base_tag=42)
prg_per_partition = generate_code_for_partition(distributed_parts)
@@ -39,23 +54,35 @@ def main():
from pytato.visualization import show_dot_graph
show_dot_graph(distributed_parts)

# Sanity check
from pytato.visualization import get_dot_graph_from_partition
get_dot_graph_from_partition(distributed_parts)
if 0:
# Sanity check
from pytato.visualization import get_dot_graph_from_partition
get_dot_graph_from_partition(distributed_parts)

# Execute the distributed partition
ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)

pt.verify_distributed_partition(comm, distributed_parts)

context = execute_distributed_partition(distributed_parts, prg_per_partition,
queue, comm)

final_res = context["out"].get(queue)
# FIXME?
if comm.size > 1:
final_res = context["out"].get(queue)
else:
final_res = context["out"]

comm.isend(x_in, dest=(rank-1) % size, tag=42)
ref_halo = comm.recv(source=(rank+1) % size, tag=42)
ref_x_plus = comm.recv(source=(rank+1) % size, tag=42)

ref_y_in = x_in + ref_x_plus

comm.isend(ref_y_in, dest=(rank-1) % size, tag=43)
ref_y_plus = comm.recv(source=(rank+1) % size, tag=43)

ref_res = x_in + ref_halo
ref_res = ref_y_in + ref_y_plus

np.testing.assert_allclose(ref_res, final_res)

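For reviewers, a condensed sketch of the new driver flow that the example above exercises. It is illustrative only: the DAG here is a stand-in (no send/recv nodes are stapled in, which the example only does when `comm.size > 1`), but the entry points, argument order, and the `comm` parameter newly taken by `find_distributed_partition` match the diff.

```python
import numpy as np
import pyopencl as cl
import pytato as pt
from mpi4py import MPI


def driver():
    comm = MPI.COMM_WORLD

    # Stand-in DAG; a real run staples distributed sends/receives onto it
    # (cf. examples/distributed.py above) when comm.size > 1.
    x = pt.make_data_wrapper(np.ones((4, 4)))
    outputs = pt.make_dict_of_named_arrays({"out": 2*x})

    # The partitioner now takes the communicator as its first argument.
    parts = pt.find_distributed_partition(comm, outputs)
    parts, _ = pt.number_distributed_tags(comm, parts, base_tag=42)
    prg_per_part = pt.generate_code_for_partition(parts)

    # Sanity check newly exercised by the example.
    pt.verify_distributed_partition(comm, parts)

    ctx = cl.create_some_context()
    queue = cl.CommandQueue(ctx)
    context = pt.execute_distributed_partition(parts, prg_per_part, queue, comm)
    return context["out"]


if __name__ == "__main__":
    print(driver())
```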
65 changes: 0 additions & 65 deletions examples/partition.py

This file was deleted.

8 changes: 3 additions & 5 deletions pytato/__init__.py
@@ -100,16 +100,15 @@ def set_debug_enabled(flag: bool) -> None:
from pytato.distributed.partition import (
find_distributed_partition, DistributedGraphPart, DistributedGraphPartition)
from pytato.distributed.tags import number_distributed_tags
from pytato.distributed.execute import (
generate_code_for_partition, execute_distributed_partition)
from pytato.distributed.verify import verify_distributed_partition
from pytato.distributed.execute import execute_distributed_partition

from pytato.transform.lower_to_index_lambda import to_index_lambda
from pytato.transform.remove_broadcasts_einsum import (
rewrite_einsums_with_no_broadcasts)
from pytato.transform.metadata import unify_axes_tags

from pytato.partition import generate_code_for_partition

__all__ = (
"dtype",

@@ -161,11 +160,10 @@ def set_debug_enabled(flag: bool) -> None:
"find_distributed_partition",

"number_distributed_tags",
"generate_code_for_partition",
"execute_distributed_partition",
"verify_distributed_partition",

"generate_code_for_partition",

"to_index_lambda",

"rewrite_einsums_with_no_broadcasts",
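For downstream users: `pytato.partition` is gone, but the top-level name `generate_code_for_partition` is still exported, now sourced from `pytato.distributed.execute`. A sketch of the import paths before and after (the old path is shown only as a comment, since that module no longer exists):

```python
# Before this PR (module now removed):
#     from pytato.partition import generate_code_for_partition

# After this PR, either of these resolves to the distributed version:
from pytato import generate_code_for_partition
from pytato.distributed.execute import generate_code_for_partition  # noqa: F811
```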
53 changes: 38 additions & 15 deletions pytato/distributed/execute.py
@@ -30,9 +30,10 @@
THE SOFTWARE.
"""

from typing import Any, Dict, Hashable, Tuple, Optional, TYPE_CHECKING
from typing import Any, Dict, Hashable, Tuple, Optional, TYPE_CHECKING, Mapping


from pytato.array import make_dict_of_named_arrays
from pytato.target import BoundProgram
from pytato.scalar_expr import INT_CLASSES

@@ -42,7 +43,7 @@
from pytato.distributed.nodes import (
DistributedRecv, DistributedSend)
from pytato.distributed.partition import (
DistributedGraphPartition, DistributedGraphPart)
DistributedGraphPartition, DistributedGraphPart, PartId)

import logging
logger = logging.getLogger(__name__)
@@ -52,6 +53,28 @@
import mpi4py.MPI


# {{{ generate_code_for_partition

def generate_code_for_partition(partition: DistributedGraphPartition) \
-> Mapping[PartId, BoundProgram]:
"""Return a mapping of partition identifiers to their
:class:`pytato.target.BoundProgram`."""
from pytato import generate_loopy
part_id_to_prg = {}

for part in sorted(partition.parts.values(),
key=lambda part_: sorted(part_.output_names)):
d = make_dict_of_named_arrays(
{var_name: partition.var_name_to_result[var_name]
for var_name in part.output_names
})
part_id_to_prg[part.pid] = generate_loopy(d)

return part_id_to_prg

# }}}
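A minimal usage sketch of the mapping returned above (hedged: `distributed_parts` is a stand-in for a `DistributedGraphPartition` such as the one built in the example); `execute_distributed_partition` below looks programs up by each part's `pid` in exactly this way.

```python
# Each part's generated program is retrieved by its pid at execution time.
prg_per_part = generate_code_for_partition(distributed_parts)
for part in distributed_parts.parts.values():
    bound_prg = prg_per_part[part.pid]  # a pytato.target.BoundProgram
```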


# {{{ distributed execute

def _post_receive(mpi_communicator: mpi4py.MPI.Comm,
@@ -88,19 +111,18 @@ def execute_distributed_partition(

from mpi4py import MPI

if len(partition.parts) != 1:
if any(part.name_to_recv_node for part in partition.parts.values()):
recv_names_tup, recv_requests_tup, recv_buffers_tup = zip(*[
(name,) + _post_receive(mpi_communicator, recv)
for part in partition.parts.values()
for name, recv in part.input_name_to_recv_node.items()])
for name, recv in part.name_to_recv_node.items()])
recv_names = list(recv_names_tup)
recv_requests = list(recv_requests_tup)
recv_buffers = list(recv_buffers_tup)
del recv_names_tup
del recv_requests_tup
del recv_buffers_tup
else:
# Only a single partition, no recv requests exist
recv_names = []
recv_requests = []
recv_buffers = []
@@ -146,13 +168,14 @@ def exec_ready_part(part: DistributedGraphPart) -> None:

context.update(result_dict)

for name, send_node in part.output_name_to_send_node.items():
# FIXME: pytato shouldn't depend on pyopencl
if isinstance(context[name], np.ndarray):
data = context[name]
else:
data = context[name].get(queue)
send_requests.append(_mpi_send(mpi_communicator, send_node, data))
for name, send_nodes in part.name_to_send_nodes.items():
for send_node in send_nodes:
# FIXME: pytato shouldn't depend on pyopencl
if isinstance(context[name], np.ndarray):
data = context[name]
else:
data = context[name].get(queue)
send_requests.append(_mpi_send(mpi_communicator, send_node, data))

pids_executed.add(part.pid)
pids_to_execute.remove(part.pid)
@@ -171,8 +194,8 @@ def wait_for_some_recvs() -> None:
buf = recv_buffers.pop(idx)

# FIXME: pytato shouldn't depend on pyopencl
import pyopencl as cl
context[name] = cl.array.to_device(queue, buf, allocator=allocator)
import pyopencl.array as cl_array
context[name] = cl_array.to_device(queue, buf, allocator=allocator)
recv_names_completed.add(name)

# {{{ main loop
Expand All @@ -182,7 +205,7 @@ def wait_for_some_recvs() -> None:
for pid in pids_to_execute
# FIXME: Only O(n**2) altogether. Nobody is going to notice, right?
if partition.parts[pid].needed_pids <= pids_executed
and (set(partition.parts[pid].input_name_to_recv_node)
and (set(partition.parts[pid].name_to_recv_node)
<= recv_names_completed)}
for pid in ready_pids:
part = partition.parts[pid]
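Review note on the renamed part attributes used above: `input_name_to_recv_node` becomes `name_to_recv_node`, and `output_name_to_send_node` becomes `name_to_send_nodes`, whose values are now collections of send nodes (the same array may be sent more than once). A hedged sketch of how a consumer iterates them, assuming the `src_rank`/`dest_rank` attributes that the send/recv constructors in the example take:

```python
# Sketch of the reworked per-part communication interface.
def describe_part_comm(part):
    for name, recv_node in part.name_to_recv_node.items():
        print(f"{part.pid}: receives {name!r} from rank {recv_node.src_rank}")
    for name, send_nodes in part.name_to_send_nodes.items():
        # A name may map to several sends now that single-send is not assumed.
        for send_node in send_nodes:
            print(f"{part.pid}: sends {name!r} to rank {send_node.dest_rank}")
```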