Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
a319455
mode
crazyxiaoxi Jul 26, 2024
86191e6
boxclip
crazyxiaoxi Jul 26, 2024
90aa9a6
test_op
crazyxiaoxi Jul 29, 2024
8a9ca8c
testop
crazyxiaoxi Jul 29, 2024
c25fede
margin_cross_entropy_op
crazyxiaoxi Jul 29, 2024
1190f7b
resolve conflict
crazyxiaoxi Jul 30, 2024
f14d22c
resolve conflict
crazyxiaoxi Jul 30, 2024
7a8d74f
resolve conflict
crazyxiaoxi Jul 30, 2024
a559de4
testop
crazyxiaoxi Jul 30, 2024
43ff6e9
testop
crazyxiaoxi Jul 30, 2024
43d5ce5
testop
crazyxiaoxi Jul 30, 2024
313c60a
tensor
crazyxiaoxi Aug 1, 2024
6380d80
fix
crazyxiaoxi Aug 1, 2024
a8ab8ac
fix
crazyxiaoxi Aug 2, 2024
1c7d7ec
test
crazyxiaoxi Aug 2, 2024
f4c5a12
fix problem
crazyxiaoxi Aug 5, 2024
f086511
Merge branch 'develop' into cinn/boxclip
crazyxiaoxi Aug 5, 2024
1c4e9aa
fix equal
crazyxiaoxi Aug 5, 2024
217614f
Update unused viriable
crazyxiaoxi Aug 6, 2024
69c72a5
static check
crazyxiaoxi Aug 6, 2024
591ae2e
rerun
crazyxiaoxi Aug 6, 2024
2ef4ce1
add cstr
crazyxiaoxi Aug 6, 2024
58f5989
Merge branch 'develop' into cinn/boxclip
crazyxiaoxi Aug 6, 2024
1f86f70
fix small problem
crazyxiaoxi Aug 6, 2024
d6b9b1d
equal constrain
crazyxiaoxi Aug 6, 2024
4e8ccc0
Merge branch 'develop' into cinn/margin
crazyxiaoxi Aug 6, 2024
1fecfe5
change
crazyxiaoxi Aug 6, 2024
520141f
Merge branch 'develop' into cinn/margin
crazyxiaoxi Aug 7, 2024
6f64112
codestyle
crazyxiaoxi Aug 7, 2024
a3d35d8
fix error
crazyxiaoxi Aug 8, 2024
d1b6a03
Merge branch 'develop' into cinn/mode
crazyxiaoxi Aug 9, 2024
569964f
Merge branch 'develop' into cinn/boxclip
crazyxiaoxi Aug 9, 2024
cdb5a9c
comment fix
crazyxiaoxi Aug 9, 2024
c24e47a
fix comment
crazyxiaoxi Aug 9, 2024
9f19580
fix
crazyxiaoxi Aug 9, 2024
7609187
CI presure
crazyxiaoxi Aug 9, 2024
b5c30b0
ci pressure
crazyxiaoxi Aug 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,46 @@ bool MatmulOpInferSymbolicShape(pir::Operation *op,
return true;
}

bool MarginCrossEntropyOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &logits_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &labels_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));

std::vector<symbol::DimExpr> logits_dims = logits_shape_or_data.shape();
std::vector<symbol::DimExpr> labels_dims = labels_shape_or_data.shape();

size_t logits_rank = logits_dims.size();
auto axis = logits_rank - 1;

for (size_t i = 0; i < logits_rank; i++) {
if (i != axis) {
infer_context->AddBroadcastableCstr(logits_dims[i], labels_dims[i]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里应该是 Equal 约束吧?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

}
}

const auto one = symbol::DimExpr{1};

if (labels_dims.size() > 1) {
infer_context->AddEqualCstr(logits_dims[axis], one);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里逻辑和infer meta不太符合,和之前的相等约束一并修改后再合并吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改


infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(logits_dims)});

logits_dims[axis] = symbol::DimExpr(1);

infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(logits_dims)});

return true;
}

bool MvOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kron)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixRankTol)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaskedSelect)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Matmul)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MarginCrossEntropy)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mv)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nextafter)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PullBoxSparse)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3006,6 +3006,7 @@
func : margin_cross_entropy
data_type : logits
backward : margin_cross_entropy_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : masked_multihead_attention_
args : (Tensor x, Tensor cache_kv, Tensor bias, Tensor src_mask, Tensor cum_offsets, Tensor sequence_lengths, Tensor rotary_tensor, Tensor beam_cache_offset, Tensor qkv_out_scale, Tensor out_shift, Tensor out_smooth, int seq_len, int rotary_emb_dims, bool use_neox_rotary_style=false, str compute_dtype = "default", float out_scale=-1, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0)
Expand Down