581
581
}}
582
582
"""
583
583
584
+ CALCULATE_LOCAL_SHAPE_KERNEL_TEMPLATE = """
585
+
586
+ auto out_grad_shape = out_grad.dims();
587
+ std::vector<{dtype}> local_kernel_shape;
588
+ const auto& out_grad_dist_attr = {out_grad_dist_attr};
589
+ for (int i = 0; i < out_grad_shape.size(); i++) {{
590
+ if (out_grad_dist_attr.dims_mapping()[i] >= 0) {{
591
+ {dtype} shape_i = out_grad_shape[i];
592
+ int64_t dim = out_grad_dist_attr.dims_mapping()[i];
593
+ int64_t mesh_dim = out_grad_dist_attr.process_mesh().shape()[dim];
594
+ // TODO: Support aliquant condition.
595
+ PADDLE_ENFORCE(shape_i % mesh_dim == 0,
596
+ common::errors::InvalidArgument(
597
+ "{op_name} only support local shape dim is divisible "
598
+ "by the mesh dim, however local_kernel_shape[%lld] is %lld "
599
+ "and shard mesh dims is %lld.", i, shape_i, mesh_dim));
600
+ local_kernel_shape.push_back(shape_i / mesh_dim);
601
+ }} else {{
602
+ local_kernel_shape.push_back(out_grad_shape[i]);
603
+ }}
604
+ }}
605
+ """
606
+
584
607
# BaseAPI members:
585
608
# inputs:
586
609
# names : [], list of input names
@@ -1755,7 +1778,7 @@ def generate_infer_meta_code(self) -> str:
1755
1778
1756
1779
return output_decl_code + infer_meta_code
1757
1780
1758
- def generate_kernel_call_code (self ) -> str :
1781
+ def generate_kernel_call_code (self , is_forward = True ) -> str :
1759
1782
dense_input_trans_map = {
1760
1783
'const Tensor&' : 'const phi::DenseTensor&' ,
1761
1784
'const std::vector<Tensor>&' : 'const std::vector<const phi::DenseTensor*>&' ,
@@ -1773,6 +1796,7 @@ def generate_kernel_call_code(self) -> str:
1773
1796
kernel_args_type_list = ['const phi::DeviceContext&' ]
1774
1797
1775
1798
attr_names = self .attrs ['names' ]
1799
+ pure_kernel_args = self .kernel ['param' ]
1776
1800
kernel_args = self .kernel ['param' ]
1777
1801
if kernel_args is None :
1778
1802
kernel_args = input_names + attr_names
@@ -1803,7 +1827,14 @@ def generate_kernel_call_code(self) -> str:
1803
1827
kernel_args_type_list .append ('const phi::IntArray&' )
1804
1828
# TODO(GhostScreaming): kernel like reshape need calculate local_shape
1805
1829
if self .infer_meta ['local_shape' ] is not None :
1806
- arg = 'phi::IntArray(local_shape)'
1830
+ if is_forward or (
1831
+ pure_kernel_args is not None
1832
+ and self .infer_meta ['local_shape' ]
1833
+ not in pure_kernel_args
1834
+ ):
1835
+ arg = 'phi::IntArray(local_shape)'
1836
+ else :
1837
+ arg = 'phi::IntArray(local_kernel_shape)'
1807
1838
else :
1808
1839
arg = 'phi::IntArray(' + arg + ')'
1809
1840
elif 'vector<phi::Scalar>' in self .attrs ['attr_info' ][arg ][0 ]:
@@ -1818,9 +1849,15 @@ def generate_kernel_call_code(self) -> str:
1818
1849
self .attrs ['attr_info' ][arg ][0 ]
1819
1850
)
1820
1851
# calculate local_shape for expand_as
1821
- # TODO(ooooo): bwd reuse this function to kernel, but actually the local_shape isn't same meaning.
1822
1852
if self .infer_meta ['local_shape' ] is not None :
1823
- arg = "local_shape"
1853
+ if is_forward or (
1854
+ pure_kernel_args is not None
1855
+ and self .infer_meta ['local_shape' ]
1856
+ not in pure_kernel_args
1857
+ ):
1858
+ arg = 'local_shape'
1859
+ else :
1860
+ arg = 'local_kernel_shape'
1824
1861
input_args .append (arg )
1825
1862
elif isinstance (arg , bool ):
1826
1863
input_args .append (str (arg ).lower ())
0 commit comments