Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1935,12 +1935,37 @@ bool SigmoidCrossEntropyWithLogits_OpInferSymbolicShape(
// return SyncBatchNormOpInferSymbolicShape(op, infer_context);
// }

// bool TdmSamplerOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }
bool TdmSamplerOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const symbol::ShapeOrDataDimExprs &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
std::vector<symbol::DimExpr> x_dims = x_shape_or_data.shape();

bool output_positive =
op->attribute<pir::BoolAttribute>("output_positive").data();
std::vector<int> neg_samples_num_list =
paddle::dialect::details::GetVectorAttr<int>(op, "neg_samples_num_list");

int64_t sample_res_length = 0;
for (int sample_nums : neg_samples_num_list) {
sample_res_length += sample_nums + static_cast<int64_t>(output_positive);
}

symbol::DimExpr batch_size = x_dims[0];
symbol::DimExpr sample_res_dim(sample_res_length);

std::vector<symbol::DimExpr> output_dims = {batch_size, sample_res_dim};
symbol::TensorShapeOrDataDimExprs output_shape(output_dims);

infer_context->SetShapeOrDataForValue(
op->result(0), symbol::ShapeOrDataDimExprs{output_shape});
infer_context->SetShapeOrDataForValue(
op->result(1), symbol::ShapeOrDataDimExprs{output_shape});
infer_context->SetShapeOrDataForValue(
op->result(2), symbol::ShapeOrDataDimExprs{output_shape});

return true;
}

bool TrilinearInterpOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(SigmoidCrossEntropyWithLogits_)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(SyncBatchNorm)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(SyncBatchNorm_)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(SaveCombine)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(TdmSampler)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TdmSampler)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TrilinearInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(HsigmoidLoss)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(ViterbiDecode)
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4681,7 +4681,7 @@
func : tdm_sampler
data_type : x
optional : labels
# interfaces : paddle::dialect::InferSymbolicShapeInterface
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : temporal_shift
args : (Tensor x, int seg_num, float shift_ratio = 0.25f, str data_format = "NCHW")
Expand Down