
[distributed] InputArgumentBase should not become "extra placeholders" #272

@kaushikcfd


The following program:

```python
import pytato as pt
import numpy as np
from pytato.distributed import (staple_distributed_send, make_distributed_recv,
                                find_distributed_partition)

from pytato.partition import generate_code_for_partition

# Hard-code rank 0 of a 2-rank run so the example runs without MPI.
rank = 0
size = 2

# Two InputArgumentBase instances: a placeholder and a data wrapper.
img1 = pt.make_placeholder("img1", shape=(100, 3), dtype="float32")
scl_np = np.random.rand(3)
scl = pt.make_data_wrapper(scl_np, name="scl")

# (100, 3) * (3,) broadcasts to (100, 3); this is the array that gets sent.
y1 = img1 * scl

# Send y1 to the left neighbour and receive the matching array from the right.
img2 = staple_distributed_send(
    y1, dest_rank=(rank-1) % size,
    comm_tag=42,
    stapled_to=make_distributed_recv(shape=img1.shape, dtype=img1.dtype,
                                     src_rank=(rank+1) % size, comm_tag=42,
                                     axes=img1.axes))
out = img1 + img2*scl

# Split the graph at the send/recv boundary and generate code for each part.
distributed_parts = find_distributed_partition(
    pt.make_dict_of_named_arrays({"out": out}))
prg_per_partition = generate_code_for_partition(distributed_parts)

for prg in prg_per_partition.values():
    print(prg.program)
```

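performs unnecessary copies. In the kernel generated for the part that computes `y1`, the input arguments `img1` (a placeholder) and `scl` (a data wrapper) are turned into the extra part placeholders `_pt_part_ph_id` and `_pt_part_ph_id_0`, which are filled by element-wise copies before the product is stored into `_pt_dist_id_0`. The redundant copies are annotated in the kernel: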

```
---------------------------------------------------------------------------
KERNEL: _pt_kernel
---------------------------------------------------------------------------
ARGUMENTS:
_pt_dist_id_0: type: np:dtype('float64'), shape: (100, 3), dim_tags: (N1:stride:3, N0:stride:1) aspace: global
_pt_part_ph_id: type: np:dtype('float32'), shape: (100, 3), dim_tags: (N1:stride:3, N0:stride:1) aspace: global
_pt_part_ph_id_0: type: np:dtype('float64'), shape: (3), dim_tags: (N0:stride:1) aspace: global
img1: type: np:dtype('float32'), shape: (100, 3), dim_tags: (N1:stride:3, N0:stride:1), offset: <class 'loopy.kernel.data.auto'> aspace: global
scl: type: np:dtype('float64'), shape: (3), dim_tags: (N0:stride:1), offset: <class 'loopy.kernel.data.auto'> aspace: global
---------------------------------------------------------------------------
DOMAINS:
{ [_pt_part_ph_id_dim0, _pt_part_ph_id_dim1] : 0 <= _pt_part_ph_id_dim0 <= 99 and 0 <= _pt_part_ph_id_dim1 <= 2 }
{ [_pt_part_ph_id_0_dim0] : 0 <= _pt_part_ph_id_0_dim0 <= 2 }
{ [_pt_dist_id_0_dim0, _pt_dist_id_0_dim1] : 0 <= _pt_dist_id_0_dim0 <= 99 and 0 <= _pt_dist_id_0_dim1 <= 2 }
---------------------------------------------------------------------------
INSTRUCTIONS:
   for _pt_part_ph_id_dim1, _pt_part_ph_id_dim0
       # WARNING: UNNECESSARY COPY!!!
↱      _pt_part_ph_id[_pt_part_ph_id_dim0, _pt_part_ph_id_dim1] = img1[_pt_part_ph_id_dim0, _pt_part_ph_id_dim1]  {id=_pt_part_ph_id_store}
│  end _pt_part_ph_id_dim1, _pt_part_ph_id_dim0
│  for _pt_part_ph_id_0_dim0
│      # WARNING: UNNECESSARY COPY!!!
│↱     _pt_part_ph_id_0[_pt_part_ph_id_0_dim0] = scl[_pt_part_ph_id_0_dim0]  {id=_pt_part_ph_id_0_store}
││ end _pt_part_ph_id_0_dim0
││ for _pt_dist_id_0_dim0, _pt_dist_id_0_dim1
└└     _pt_dist_id_0[_pt_dist_id_0_dim0, _pt_dist_id_0_dim1] = _pt_part_ph_id[_pt_dist_id_0_dim0, _pt_dist_id_0_dim1]*_pt_part_ph_id_0[_pt_dist_id_0_dim1]  {id=_pt_dist_id_0_store}
   end _pt_dist_id_0_dim0, _pt_dist_id_0_dim1
---------------------------------------------------------------------------
```
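To spot copies like the two flagged in this kernel without reading the listing by hand, one rough option is to scan the printed instructions for assignments of the form `dst[idx] = src[idx]` with identical index expressions on both sides. The sketch below does exactly that. `find_pure_copies` is a hypothetical helper, not part of pytato or loopy, and it assumes the instruction format shown in the listing (optional dependency-arrow glyphs, then `lhs = rhs  {id=...}`). Because it matches text only, it will also flag copies that are genuinely required.

```python
import re
from typing import Iterable, List, Tuple

# Matches a loopy-style assignment such as
#   dst[i, j] = src[i, j]  {id=...}
# The optional trailing {...} holds loopy's instruction options.
_COPY_RE = re.compile(
    r"^\s*(?P<dst>\w+)\[(?P<dst_idx>[^\]]*)\]\s*=\s*"
    r"(?P<src>\w+)\[(?P<src_idx>[^\]]*)\]\s*(?:\{.*\})?\s*$"
)

# Indentation and dependency-arrow glyphs that prefix instruction lines.
_PREFIX_CHARS = " │↱└|"


def find_pure_copies(instruction_lines: Iterable[str]) -> List[Tuple[str, str]]:
    """Return (dst, src) pairs for instructions that copy ``src`` into ``dst``
    element by element, i.e. with identical index expressions on both sides."""
    copies = []
    for line in instruction_lines:
        match = _COPY_RE.match(line.lstrip(_PREFIX_CHARS))
        if match is None:
            # Loop headers, "end" lines, comments and non-copy expressions.
            continue
        dst_idx = [s.strip() for s in match["dst_idx"].split(",")]
        src_idx = [s.strip() for s in match["src_idx"].split(",")]
        if dst_idx == src_idx:
            copies.append((match["dst"], match["src"]))
    return copies


# Instruction lines taken from the kernel listing in this issue.
kernel_lines = [
    "   for _pt_part_ph_id_dim1, _pt_part_ph_id_dim0",
    "↱      _pt_part_ph_id[_pt_part_ph_id_dim0, _pt_part_ph_id_dim1]"
    " = img1[_pt_part_ph_id_dim0, _pt_part_ph_id_dim1]  {id=_pt_part_ph_id_store}",
    "│↱     _pt_part_ph_id_0[_pt_part_ph_id_0_dim0]"
    " = scl[_pt_part_ph_id_0_dim0]  {id=_pt_part_ph_id_0_store}",
    "└└     _pt_dist_id_0[_pt_dist_id_0_dim0, _pt_dist_id_0_dim1]"
    " = _pt_part_ph_id[_pt_dist_id_0_dim0, _pt_dist_id_0_dim1]"
    "*_pt_part_ph_id_0[_pt_dist_id_0_dim1]  {id=_pt_dist_id_0_store}",
]
print(find_pure_copies(kernel_lines))
```

On these lines it reports `("_pt_part_ph_id", "img1")` and `("_pt_part_ph_id_0", "scl")`, and skips the product instruction. Against a real run, the input could be `str(prg.program).splitlines()`, assuming the printed format matches the listing above.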

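The title states the intended behaviour: `InputArgumentBase` arrays such as `img1` (a placeholder) and `scl` (a data wrapper) should not become extra part placeholders, presumably because every part can read them directly without communication or an intermediate copy. The toy model below is a hypothetical sketch of that rule, not pytato's partitioner. `Node`, `InputArg`, `Op` and `part_placeholders` are illustrative names, and the sent/received `y1` is simplified to a direct cross-part read. With `exempt_inputs=False` it reproduces the extra placeholders for `img1` and `scl` seen in the kernel. With the rule applied, only arrays computed in another part still need a placeholder.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, Set, Tuple


@dataclass(frozen=True)
class Node:
    name: str


@dataclass(frozen=True)
class InputArg(Node):
    """Toy stand-in for an InputArgumentBase (placeholder or data wrapper)."""


@dataclass(frozen=True)
class Op(Node):
    """Toy computed array that reads the arrays in ``operands``."""
    operands: Tuple[Node, ...] = ()


def part_placeholders(roots: Iterable[Op], part_of: Dict[Op, int],
                      exempt_inputs: bool = True) -> Dict[int, Set[str]]:
    """Map each part id to the names it would receive as part placeholders.

    An Op operand computed in a different part always needs a placeholder.
    With ``exempt_inputs=False``, every input argument a part reads also gets
    one, mimicking the copies of ``img1`` and ``scl`` in the issue.
    """
    result: Dict[int, Set[str]] = {}
    seen: Set[Node] = set()
    stack = list(roots)
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        part = part_of[node]
        result.setdefault(part, set())
        for operand in node.operands:
            if isinstance(operand, InputArg):
                # Input arguments can be read directly by any part.
                if not exempt_inputs:
                    result[part].add(operand.name)
            else:
                if part_of[operand] != part:
                    result[part].add(operand.name)
                stack.append(operand)
    return result


# Toy version of the issue's graph: y1 = img1 * scl in part 0, and
# out = img1 + y1 * scl in part 1 (the send/recv of y1 is simplified away).
img1 = InputArg("img1")
scl = InputArg("scl")
y1 = Op("y1", (img1, scl))
out = Op("out", (img1, y1, scl))
part_of = {y1: 0, out: 1}

fixed = part_placeholders([out], part_of)
buggy = part_placeholders([out], part_of, exempt_inputs=False)
print("fixed:", fixed)
print("buggy:", buggy)
```

In the "fixed" result, part 0 needs no placeholders at all, which corresponds to a kernel for `y1` that reads `img1` and `scl` directly instead of copying them first.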