-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Clean cross entropy #10280
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Clean cross entropy #10280
Conversation
: dx_(dx), dy_(dy), x_(x), label_(label), num_classes_(num_classes) {} | ||
|
||
HOSTDEVICE void operator()(size_t label_id) { | ||
auto x_is_true_offset = label_id * num_classes_ + label_[label_id]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe label_id
is not exact, sample_id
may be better.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
dx_[x_offset] = x_offset != x_is_true_offset | ||
? static_cast<T>(0) | ||
: -dy_[label_id] / x_[x_offset]; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I prefer to write like this:
auto x_is_true_offset = label_id * num_classes_ + label_[label_id];
dx_[x_is_true_offset] = -dy_[label_id] / x_[x_is_true_offset]
and put setting dx_
to zero in here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is slower to use two kernels than one kernel.
… feature/add_stable_test_of_cross_entropy
…:reyoung/Paddle into feature/add_stable_test_of_cross_entropy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excellent!
Unify cross entropy CPU and GPU kernel by
platform::ForRange
functor. It reduce the lines of code ofcross_entropy_op