Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/ir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/fusion.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
{def_primitive}
Expand Down
8 changes: 7 additions & 1 deletion paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,17 @@ set(op_backward_yaml_file1
set(op_backward_yaml_file2
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml
)
set(fused_op_forward_yaml_file
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_ops.parsed.yaml
)
set(fused_op_backward_yaml_file
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_backward.parsed.yaml
)
set(op_yaml_file3
${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.yaml)

set(op_yaml_files
${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${op_yaml_file3}
${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${fused_op_forward_yaml_file},${fused_op_backward_yaml_file},${op_yaml_file3}
)
set(op_namespace paddle,dialect)
set(dialect_name pd)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/fused_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
output: Tensor(out), Tensor(seq_lod), Tensor(max_seq_len)
infer_meta :
func: EmbeddingWithEltwiseAddXPUInferMeta
param : [ids, tables, mask]
kernel:
func: embedding_with_eltwise_add_xpu
data_type: tables
Expand Down
10 changes: 5 additions & 5 deletions paddle/phi/infermeta/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -466,11 +466,11 @@ void FusedMultiTransformerXpuInferMeta(
const std::vector<const MetaTensor*>& ffn2_bias,
const std::vector<const MetaTensor*>& cache_kv,
const std::vector<const MetaTensor*>& pre_caches,
const std::vector<const MetaTensor*>& rotary_pos_emb,
const std::vector<const MetaTensor*>& time_step,
const std::vector<const MetaTensor*>& seq_lengths,
const std::vector<const MetaTensor*>& src_mask,
const std::vector<const MetaTensor*>& gather_index,
const MetaTensor& rotary_pos_emb,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this PR include modifications to this operator? [original comment in Chinese: 为啥这个PR会带上该算子的修改呢?]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change has been confirmed with @zyfncg. The reasons for the modification are:
(1) In the operator definition, the last five input parameters (e.g. rotary_pos_emb) are of type Tensor, not vector; therefore the infer_meta signature is normalized here from `const std::vector<const MetaTensor*>&` to `const MetaTensor&`.
(2) Additionally, these last five input parameters are not used inside the infer-meta function.

const MetaTensor& time_step,
const MetaTensor& seq_lengths,
const MetaTensor& src_mask,
const MetaTensor& gather_index,
bool pre_layer_norm,
int rotary_emb_dims,
float epsilon,
Expand Down
10 changes: 5 additions & 5 deletions paddle/phi/infermeta/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,11 @@ void FusedMultiTransformerXpuInferMeta(
const std::vector<const MetaTensor*>& ffn2_bias,
const std::vector<const MetaTensor*>& cache_kv,
const std::vector<const MetaTensor*>& pre_caches,
const std::vector<const MetaTensor*>& rotary_pos_emb,
const std::vector<const MetaTensor*>& time_step,
const std::vector<const MetaTensor*>& seq_lengths,
const std::vector<const MetaTensor*>& src_mask,
const std::vector<const MetaTensor*>& gather_index,
const MetaTensor& rotary_pos_emb,
const MetaTensor& time_step,
const MetaTensor& seq_lengths,
const MetaTensor& src_mask,
const MetaTensor& gather_index,
bool pre_layer_norm,
int rotary_emb_dims,
float epsilon,
Expand Down