-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Add modified huber loss operator #3987
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
}; | ||
|
||
template <typename T> | ||
struct ModifiedHuberLossForward { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么ModifiedHuberLossForward单独列出来,而backward却没有呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
前向是因为CPU和GPU共用,但是BP时两者分离,而且CPU逻辑比较简单,所以直接写循环了,GPU实现用的thrust,所以单列了
} else if (inter_val_ptr[i] < 1) { | ||
x_grad_ptr[i] = -2 * (1 - inter_val_ptr[i]) * (2 * y_ptr[i] - 1) * | ||
out_grad_ptr[i]; | ||
} else { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个反向写的不对吧。92行和95行,都不需要再乘一个out_grad_ptr[i]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@luotao1 在operator里,loss的写法和Paddle Layer的写法不一样,output_grad即使是1.0也是框架自动设置,或loss op外面自动设置的,所以需要乘out_grad_ptr[i]
|
||
PADDLE_ENFORCE_EQ(x->dims(), y->dims(), | ||
"Dimensions of X and Y must be the same."); | ||
PADDLE_ENFORCE_EQ(framework::arity(x->dims()), 2, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
framework::arity(x->dims())
--> x->dims().size()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
"Dimensions of X and Y must be the same."); | ||
PADDLE_ENFORCE_EQ(framework::arity(x->dims()), 2, | ||
"Tensor rank of X must be 2."); | ||
PADDLE_ENFORCE_EQ(x->dims()[1], 1, "Second dimension of X must be 1."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
--> The second
or The 2nd
for short.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("X", "Input value of ModifiedHuberLossOp."); | ||
AddInput("Y", "Target labels of ModifiedHuberLossOp."); | ||
AddOutput("intermediate_val", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we have the same naming style? e.g. IntermediateVal
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
def modified_huber_loss_forward(val): | ||
if val < -1: | ||
return -4 * a |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is a
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a --> val. Done.
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-3923-c
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Resolves #3923