Commit a8dcab8

[Auto Parallel] Try to refine dist_api_gen (#73114)
* refine code
* fix
1 parent 6ad2b01

2 files changed: +67 −8 lines changed


paddle/phi/api/generator/dist_api_gen.py (41 additions, 4 deletions)
```diff
@@ -581,6 +581,29 @@
 }}
 """
 
+CALCULATE_LOCAL_SHAPE_KERNEL_TEMPLATE = """
+
+      auto out_grad_shape = out_grad.dims();
+      std::vector<{dtype}> local_kernel_shape;
+      const auto& out_grad_dist_attr = {out_grad_dist_attr};
+      for (int i = 0; i < out_grad_shape.size(); i++) {{
+        if (out_grad_dist_attr.dims_mapping()[i] >= 0) {{
+          {dtype} shape_i = out_grad_shape[i];
+          int64_t dim = out_grad_dist_attr.dims_mapping()[i];
+          int64_t mesh_dim = out_grad_dist_attr.process_mesh().shape()[dim];
+          // TODO: Support aliquant condition.
+          PADDLE_ENFORCE(shape_i % mesh_dim == 0,
+            common::errors::InvalidArgument(
+                "{op_name} only support local shape dim is divisible "
+                "by the mesh dim, however local_kernel_shape[%lld] is %lld "
+                "and shard mesh dims is %lld.", i, shape_i, mesh_dim));
+          local_kernel_shape.push_back(shape_i / mesh_dim);
+        }} else {{
+          local_kernel_shape.push_back(out_grad_shape[i]);
+        }}
+      }}
+"""
+
 # BaseAPI members:
 # inputs:
 #     names : [], list of input names
```
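The new `CALCULATE_LOCAL_SHAPE_KERNEL_TEMPLATE` emits C++ that shrinks each sharded dimension of `out_grad`'s global shape to its per-rank extent: a dimension is sharded when its `dims_mapping` entry is non-negative, and its local extent is the global extent divided by the size of the mesh dimension it is mapped to. A minimal Python sketch of that arithmetic (the function name and arguments are hypothetical, for illustration only):

```python
def local_kernel_shape(global_shape, dims_mapping, mesh_shape):
    """Mirror of the loop in CALCULATE_LOCAL_SHAPE_KERNEL_TEMPLATE."""
    local = []
    for i, extent in enumerate(global_shape):
        mesh_dim = dims_mapping[i]
        if mesh_dim >= 0:
            ranks = mesh_shape[mesh_dim]
            # The template PADDLE_ENFORCEs divisibility; aliquant
            # (non-divisible) sharding is still a TODO upstream.
            assert extent % ranks == 0, (
                f"dim {i}: {extent} not divisible by mesh dim size {ranks}"
            )
            local.append(extent // ranks)
        else:
            # Unsharded dims keep their global extent.
            local.append(extent)
    return local

# An [8, 6] tensor with dims_mapping [0, -1] on a [2, 3] mesh:
# dim 0 is split across 2 ranks, dim 1 is replicated.
print(local_kernel_shape([8, 6], [0, -1], [2, 3]))  # [4, 6]
```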
```diff
@@ -1755,7 +1778,7 @@ def generate_infer_meta_code(self) -> str:
 
         return output_decl_code + infer_meta_code
 
-    def generate_kernel_call_code(self) -> str:
+    def generate_kernel_call_code(self, is_forward=True) -> str:
         dense_input_trans_map = {
             'const Tensor&': 'const phi::DenseTensor&',
             'const std::vector<Tensor>&': 'const std::vector<const phi::DenseTensor*>&',
@@ -1773,6 +1796,7 @@ def generate_kernel_call_code(self) -> str:
         kernel_args_type_list = ['const phi::DeviceContext&']
 
         attr_names = self.attrs['names']
+        pure_kernel_args = self.kernel['param']
         kernel_args = self.kernel['param']
         if kernel_args is None:
             kernel_args = input_names + attr_names
```
```diff
@@ -1803,7 +1827,14 @@ def generate_kernel_call_code(self) -> str:
                     kernel_args_type_list.append('const phi::IntArray&')
                     # TODO(GhostScreaming): kernel like reshape need calculate local_shape
                     if self.infer_meta['local_shape'] is not None:
-                        arg = 'phi::IntArray(local_shape)'
+                        if is_forward or (
+                            pure_kernel_args is not None
+                            and self.infer_meta['local_shape']
+                            not in pure_kernel_args
+                        ):
+                            arg = 'phi::IntArray(local_shape)'
+                        else:
+                            arg = 'phi::IntArray(local_kernel_shape)'
                     else:
                         arg = 'phi::IntArray(' + arg + ')'
                 elif 'vector<phi::Scalar>' in self.attrs['attr_info'][arg][0]:
@@ -1818,9 +1849,15 @@ def generate_kernel_call_code(self) -> str:
                         self.attrs['attr_info'][arg][0]
                     )
                     # calculate local_shape for expand_as
-                    # TODO(ooooo): bwd reuse this function to kernel, but actually the local_shape isn't same meaning.
                     if self.infer_meta['local_shape'] is not None:
-                        arg = "local_shape"
+                        if is_forward or (
+                            pure_kernel_args is not None
+                            and self.infer_meta['local_shape']
+                            not in pure_kernel_args
+                        ):
+                            arg = 'local_shape'
+                        else:
+                            arg = 'local_kernel_shape'
                     input_args.append(arg)
                 elif isinstance(arg, bool):
                     input_args.append(str(arg).lower())
```
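Both the `IntArray` branch and the scalar-attribute branch above gate on the same new condition. It reduces to the following predicate, restated here as a hedged standalone sketch (the helper name is hypothetical):

```python
def use_infer_meta_local_shape(is_forward, pure_kernel_args, local_shape_attr):
    # Mirrors the added branches: forward APIs keep the infer_meta-derived
    # 'local_shape'; a backward API keeps it too when the kernel's explicit
    # 'param' list does not include the shape attribute. Only a backward
    # kernel that consumes the attribute gets 'local_kernel_shape'.
    return is_forward or (
        pure_kernel_args is not None
        and local_shape_attr not in pure_kernel_args
    )

print(use_infer_meta_local_shape(True, None, 'shape'))             # True
print(use_infer_meta_local_shape(False, ['x', 'shape'], 'shape'))  # False
print(use_infer_meta_local_shape(False, ['x'], 'shape'))           # True
```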

paddle/phi/api/generator/dist_bw_api_gen.py (26 additions, 4 deletions)
```diff
@@ -288,7 +288,7 @@ def generate_output_creation_code(self) -> str:
 
         return output_creation_code
 
-    def generate_bw_infer_local_shape_code(self) -> str:
+    def generate_bw_infer_local_shape_code(self, need_kernel=False):
         arg_name = self.infer_meta['local_shape']
         assert arg_name in self.outputs['names'], (
             f"Auto Parallel will calculate local_shape for {arg_name} "
@@ -304,7 +304,7 @@ def generate_bw_infer_local_shape_code(self) -> str:
             self.outputs['names'].index(arg_name)
         ]
         shape_type = self.get_shape_type(fw_attrs['attr_info'])
-        return dist_api_gen.CALCULATE_LOCAL_SHAPE_TEMPLATE.format(
+        return_code = dist_api_gen.CALCULATE_LOCAL_SHAPE_TEMPLATE.format(
             out_name=dist_out_name,
             out_dist_attr=(
                 "PADDLE_GET_CONST(phi::distributed::TensorDistAttr, spmd_info.second[0]);"
@@ -314,6 +314,20 @@ def generate_bw_infer_local_shape_code(self) -> str:
             dtype=shape_type,
             op_name=self.kernel['func'][0],
         )
+        if need_kernel:
+            return (
+                dist_api_gen.CALCULATE_LOCAL_SHAPE_KERNEL_TEMPLATE.format(
+                    out_grad_dist_attr=(
+                        "PADDLE_GET_CONST(phi::distributed::TensorDistAttr, spmd_info.first[1]);"
+                        if self.infer_meta['spmd_rule']
+                        else "phi::distributed::TensorDistAttr(common::vectorize(out_grad.dims()))"
+                    ),
+                    dtype=shape_type,
+                    op_name=self.kernel['func'][0],
+                )
+                + return_code
+            )
+        return return_code
 
     def generate_infer_meta_code(self) -> str:
         (
```
```diff
@@ -347,7 +361,15 @@ def generate_infer_meta_code(self) -> str:
         )
         # TODO(GhostScreaming): kernel like reshape need calculate local_shape
         if self.infer_meta['local_shape'] is not None:
-            infer_meta_code += self.generate_bw_infer_local_shape_code()
+            if (
+                self.kernel['param'] is not None
+                and self.infer_meta['local_shape'] not in self.kernel['param']
+            ):
+                infer_meta_code += self.generate_bw_infer_local_shape_code()
+            else:
+                infer_meta_code += self.generate_bw_infer_local_shape_code(
+                    need_kernel=True
+                )
             infer_meta_code += SET_LOCAL_SHAPE_TEMPLATE.format(
                 meta_tensor="meta_" + self.dense_output_args[0]
             )
```
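The dispatch above is the complement of the forward-side check in dist_api_gen.py: `need_kernel=True` exactly when the kernel consumes the shape attribute, i.e. when the kernel `param` list is unspecified (`None` means every input and attribute is passed through, per the earlier `kernel_args = input_names + attr_names` fallback) or explicitly lists it. A sketch, with a hypothetical helper name:

```python
def needs_kernel_local_shape(kernel_params, local_shape_attr):
    # Equivalent to the 'else' branch above taking effect.
    return kernel_params is None or local_shape_attr in kernel_params

print(needs_kernel_local_shape(None, 'shape'))            # True
print(needs_kernel_local_shape(['x', 'shape'], 'shape'))  # True
print(needs_kernel_local_shape(['x'], 'shape'))           # False
```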
```diff
@@ -462,7 +484,7 @@ def generate_auto_parallel_branch(self) -> str:
             )
         )
         infer_meta_code = self.generate_infer_meta_code()
-        kernel_call_code = self.generate_kernel_call_code()
+        kernel_call_code = self.generate_kernel_call_code(is_forward=False)
         fallback_code = self.generate_fallback_code()
         reshard_output_code = self.generate_reshard_output_code()
         return_code = self.generate_return_code()
```
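Taken together, the backward generator now threads `is_forward=False` into the shared `generate_kernel_call_code`: for a backward op whose dense kernel takes the shape attribute, the generated branch first materializes `local_kernel_shape` from `out_grad`'s dist attr via `CALCULATE_LOCAL_SHAPE_KERNEL_TEMPLATE`, then passes `phi::IntArray(local_kernel_shape)` to the kernel, while infer_meta continues to receive the `local_shape` computed by `CALCULATE_LOCAL_SHAPE_TEMPLATE`.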
