Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PIR OpTest Fix No.14】 fix test_nce #60255

Merged
merged 9 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
'fused_scale_bias_add_relu',
'fused_dconv_drelu_dbn',
'fused_dot_product_attention',
'nce',
'lars_momentum',
'recv_v2',
'rnn_',
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1471,6 +1471,17 @@
data_type: param
optional: master_param, master_param_out

# NCE (noise-contrastive estimation) forward op. `bias` and the sampling
# distribution tensors are optional; `sample_logits`/`sample_labels` are
# intermediate outputs consumed by nce_grad.
- op: nce
  args: (Tensor input, Tensor label, Tensor weight, Tensor bias, Tensor sample_weight, Tensor custom_dist_probs, Tensor custom_dist_alias, Tensor custom_dist_alias_probs, int num_total_classes, int[] custom_neg_classes={}, int num_neg_samples=10, int sampler=0, int seed=0, bool is_sparse=false, bool remote_prefetch=false, bool is_test=false)
  output: Tensor(cost), Tensor(sample_logits), Tensor(sample_labels)
  infer_meta:
    func: NceInferMeta
  kernel:
    func: nce
    data_type: input
  optional: bias, sample_weight, custom_dist_probs, custom_dist_alias, custom_dist_alias_probs
  backward: nce_grad

- op: number_count
args: (Tensor numbers, int upper_range)
output: Tensor(out)
Expand Down
12 changes: 12 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,18 @@
func : multiply_triple_grad
optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_x_grad, grad_y_grad, grad_grad_out_grad

# Backward op for NCE. Fixes: the forward op name was misspelled ("nec"),
# and the forward signature omitted the `custom_neg_classes` attribute that
# the forward op declares — both must match the `- op: nce` definition.
- backward_op : nce_grad
  forward: nce (Tensor input, Tensor label, Tensor weight, Tensor bias, Tensor sample_weight, Tensor custom_dist_probs, Tensor custom_dist_alias, Tensor custom_dist_alias_probs, int num_total_classes, int[] custom_neg_classes={}, int num_neg_samples=10, int sampler=0, int seed=0, bool is_sparse=false, bool remote_prefetch=false, bool is_test=false) -> Tensor(cost), Tensor(sample_logits), Tensor(sample_labels)
  args: (Tensor input, Tensor label, Tensor bias, Tensor weight, Tensor sample_logits, Tensor sample_labels, Tensor sample_weight, Tensor custom_dist_probs, Tensor custom_dist_alias, Tensor custom_dist_alias_probs, Tensor cost_grad, int num_total_classes, int[] custom_neg_classes={}, int num_neg_samples=10, int sampler=0, int seed=0, bool is_sparse=false, bool remote_prefetch=false, bool is_test=false)
  output: Tensor(input_grad), Tensor(bias_grad), Tensor(weight_grad)
  infer_meta:
    func: NceGradInferMeta
    param: [input, bias, weight]
  kernel:
    func: nce_grad
    data_type: input
  optional: bias, sample_weight, custom_dist_probs, custom_dist_alias, custom_dist_alias_probs

- backward_op : norm_grad
forward : norm (Tensor x, int axis, float epsilon, bool is_test) -> Tensor(out), Tensor(norm)
args : (Tensor x, Tensor norm, Tensor out_grad, int axis, float epsilon, bool is_test)
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ const std::unordered_set<std::string> LegacyOpList = {
RowConvGradOp::name(),
SoftReluOp::name(),
SoftReluGradOp::name(),
NceOp::name(),
NceGradOp::name(),
CReduceMinOp::name()};

const std::unordered_set<std::string> OneDNNLegacyOpList = {};
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3510,6 +3510,13 @@
outputs :
out : Out

# Name-mapping entry: maps the phi/PIR lowercase argument names of `nce`
# (and its backward) onto the legacy fluid operator's capitalized
# Input/Output names so old programs can be translated.
- op: nce
  backward: nce_grad
  inputs:
    {input : Input, label : Label, weight : Weight, bias : Bias, sample_weight : SampleWeight, custom_dist_probs : CustomDistProbs, custom_dist_alias : CustomDistAlias, custom_dist_alias_probs : CustomDistAliasProbs}
  outputs:
    {cost : Cost, sample_logits : SampleLogits, sample_labels : SampleLabels}

- op: number_count
inputs :
{numbers: numbers}
Expand Down
27 changes: 27 additions & 0 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,33 @@ void NanmedianGradInferMeta(const MetaTensor& x,
x_grad->set_dtype(x.dtype());
}

// Infer meta (dims/dtype) for the NCE backward op.
//
// Each gradient output, when requested (non-null), mirrors the dims and
// dtype of its corresponding forward input. Fixes vs. original: removed a
// stray blank line inside the parameter list, and made the three null
// checks use a consistent form.
void NceGradInferMeta(const MetaTensor& input,
                      const MetaTensor& bias,
                      const MetaTensor& weight,
                      MetaTensor* input_grad,
                      MetaTensor* bias_grad,
                      MetaTensor* weight_grad) {
  if (input_grad != nullptr) {
    input_grad->set_dims(input.dims());
    input_grad->set_dtype(input.dtype());
  }

  if (weight_grad != nullptr) {
    weight_grad->set_dims(weight.dims());
    weight_grad->set_dtype(weight.dtype());
  }

  // `bias` is an optional input; callers pass bias_grad == nullptr when it
  // is absent, so the unconditional bias.dims() read only happens via the
  // guarded branch now.
  if (bias_grad != nullptr) {
    bias_grad->set_dims(bias.dims());
    bias_grad->set_dtype(bias.dtype());
  }
}

void NllLossGradInferMeta(const MetaTensor& x,
const MetaTensor& label,
const MetaTensor& weight,
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,13 @@ void NanmedianGradInferMeta(const MetaTensor& x,
bool keep_dim,
MetaTensor* x_grad);

// Infer meta for nce_grad: each non-null *_grad output is given the dims
// and dtype of the matching forward input (input/bias/weight).
void NceGradInferMeta(const MetaTensor& input,
const MetaTensor& bias,
const MetaTensor& weight,
MetaTensor* input_grad,
MetaTensor* bias_grad,
MetaTensor* weight_grad);

void NllLossGradInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& weight,
Expand Down
92 changes: 92 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3181,6 +3181,98 @@ void MultiplexInferMeta(const std::vector<const MetaTensor*>& ins,
out->set_dtype(ins[0]->dtype());
}

void NceInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& weight,
const MetaTensor& bias,
const MetaTensor& sample_weight,
const MetaTensor& custom_dist_probs,
const MetaTensor& custom_dist_alias,
const MetaTensor& custom_dist_alias_probs,
int num_total_classes,
const std::vector<int>& custom_neg_classes,
int num_neg_samples,
int sampler,
int seed,
bool is_sparse,
bool remote_prefetch,
bool is_test,
MetaTensor* cost,
MetaTensor* sample_logits,
MetaTensor* sample_labels,
MetaConfig config) {
auto x_dims = input.dims();
auto label_dims = label.dims();
if (config.is_runtime || (x_dims[0] > 0 && label_dims[0] > 0)) {
PADDLE_ENFORCE_EQ(
x_dims[0],
label_dims[0],
phi::errors::InvalidArgument(
"The first dimension of Input(Input) and Input(Label) should be "
"equal in runtime. But received: Input(Input)'s shape = [%s] "
"with 1st dim = %d, Input(Label)'s shape = [%s] with 1st dim = "
"%d.",
x_dims,
x_dims[0],
label_dims,
label_dims[0]));
}
int num_true_classes =
static_cast<int>(label_dims.size() == 2 ? label_dims[1] : 1);
if (bias) {
PADDLE_ENFORCE_EQ(
weight.dims()[0],
bias.dims()[0],
phi::errors::InvalidArgument(
"The first dimension of Input(Weight) and Input(Bias) "
"should be equal. But received: Input(Weight)'s shape = [%s] "
"with 1st dim = %d, and Input(Bias)'s shape = [%s] with 1st dim "
"= %d.",
weight.dims(),
weight.dims()[0],
bias.dims(),
bias.dims()[0]));
}

PADDLE_ENFORCE_EQ(
num_total_classes,
weight.dims()[0],
phi::errors::InvalidArgument(
"The number of total classes should be equal to the first "
"dimension of Input(Weight). But received: Attr(num_total_classes) "
"= %d, Input(Weight)'s shape = [%s] with 1st dim = %d.",
num_total_classes,
weight.dims(),
weight.dims()[0]));
if (custom_neg_classes.size() > 0) {
PADDLE_ENFORCE_EQ(
custom_neg_classes.size(),
static_cast<size_t>(num_neg_samples),
phi::errors::InvalidArgument(
"The size of Attr(custom_neg_classes) should be equal "
"to the number of negative samples. But received: "
"custom_neg_classes.size() = %d, num_neg_samples = %d.",
custom_neg_classes.size(),
num_neg_samples));
}
// set dims of output(Out)
std::vector<int64_t> out_dims;
out_dims.push_back(x_dims[0]);
out_dims.push_back(1);
cost->set_dims(common::make_ddim(out_dims));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已设置

cost->set_dtype(DataType::FLOAT32);

if (!is_test) {
// set dims of output(SampleOut)
std::vector<int64_t> sample_out_dims;
sample_out_dims.push_back(x_dims[0]);
sample_out_dims.push_back(
(num_true_classes == -1) ? -1 : (num_neg_samples + num_true_classes));
sample_logits->set_dims(common::make_ddim(sample_out_dims));
sample_labels->set_dims(common::make_ddim(sample_out_dims));
}
}

void PsroiPoolInferMeta(const MetaTensor& x,
const MetaTensor& rois,
const MetaTensor& rois_num,
Expand Down
21 changes: 21 additions & 0 deletions paddle/phi/infermeta/multiary.h
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,27 @@ void MultiplexInferMeta(const std::vector<const MetaTensor*>& ins,
const MetaTensor& ids,
MetaTensor* out);

// Infer meta for nce: validates batch/class-count consistency among
// input, label, weight and optional bias; sets cost to [batch, 1]
// (FLOAT32) and, in training mode, sample_logits/sample_labels to
// [batch, num_neg_samples + num_true_classes].
void NceInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& weight,
const MetaTensor& bias,
const MetaTensor& sample_weight,
const MetaTensor& custom_dist_probs,
const MetaTensor& custom_dist_alias,
const MetaTensor& custom_dist_alias_probs,
int num_total_classes,
const std::vector<int>& custom_neg_classes,
int num_neg_samples,
int sampler,
int seed,
bool is_sparse,
bool remote_prefetch,
bool is_test,
MetaTensor* cost,
MetaTensor* sample_logits,
MetaTensor* sample_labels,
MetaConfig config = MetaConfig());

void PsroiPoolInferMeta(const MetaTensor& x,
const MetaTensor& rois,
const MetaTensor& rois_num,
Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ test_multinomial_op
test_multiplex_op
test_mv_op
test_nanmedian
test_nce
test_nearest_interp_mkldnn_op
test_nearest_interp_v2_op
test_nextafter_op
Expand Down