rfc: dropout primitive attribute #1708
Conversation
Thanks for the RFC and prototype. Could you clarify the parameters for the attribute? For example, is the memory descriptor for an output mask or an input mask? I see two ways:
- the user passes a probability and the primitive computes the mask;
- the user passes the mask and the primitive just applies it.
If you have training in mind, I wonder if we need to support both (on the forward pass, generate the mask; on the backward pass, consume it).
Other questions:
- What kind of randomness is required of oneDNN implementations (how much entropy? what distribution? ...)?
- What is the plan to validate this attribute (the masking part should be fine; I am asking about the randomness)?
- We might want to add a knob for the user to set the random seed, so that runs can be reproduced.
rfcs/20230818-Dropout/README.md
Outdated
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_dropout(
        const_dnnl_primitive_attr_t attr,
        float *p, const_dnnl_memory_desc_t *drop_desc);
Do I understand correctly that, if a convolution has the dropout attribute set, on forward we would:
- compute the dropout mask,
- apply the mask to the destination,
- and write the mask to a new output memory?
The question is how that mask would be applied on backward. For example, say we have Conv -> Dropout -> Relu. Do we expect the user to pass the dropout mask to the Relu backward computation or to the Conv backward computation? I would think the former would be simpler to implement.
Yes, the mentioned algorithm is what we expect from this primitive attribute: compute and apply the mask.
On the backward pass, the mask from forward is multiplied by the backward input, i.e. the incoming gradient (see native_dropout_backward in PyTorch).
The main idea was to move mask generation inside oneDNN. Otherwise, the user can just use a binary_mul post-op to apply an existing mask.
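To make that concrete, below is a minimal plain-C sketch of the reference behavior being discussed (this is not oneDNN code; the rand()-based generator is only a stand-in, and the inverted-dropout 1/(1 - p) scaling follows the PyTorch reference): forward generates a keep-mask, applies it to the destination, and writes the mask out; backward multiplies the incoming gradient by the same mask, as native_dropout_backward does.

```c
#include <stdlib.h>

/* Reference (inverted) dropout, forward pass: keep each element with
 * probability 1 - p, scale kept elements by 1/(1 - p), and record the 0/1
 * keep-mask so the backward pass can reuse it. */
static void dropout_fwd(const float *src, float *dst, float *mask, size_t n,
        float p, unsigned seed) {
    srand(seed); /* placeholder RNG, just for the sketch */
    const float scale = 1.0f / (1.0f - p);
    for (size_t i = 0; i < n; ++i) {
        const float keep
                = ((float)rand() / (float)RAND_MAX) >= p ? 1.0f : 0.0f;
        mask[i] = keep;
        dst[i] = src[i] * keep * scale;
    }
}

/* Backward pass: multiply the incoming gradient by the saved mask (and the
 * same scale), mirroring PyTorch's native_dropout_backward. In a
 * Conv -> Dropout -> Relu chain, this is applied to the gradient coming out
 * of Relu backward before it is fed into (or fused with) Conv backward. */
static void dropout_bwd(const float *diff_dst, const float *mask,
        float *diff_src, size_t n, float p) {
    const float scale = 1.0f / (1.0f - p);
    for (size_t i = 0; i < n; ++i)
        diff_src[i] = diff_dst[i] * mask[i] * scale;
}
```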
Thanks. That clarifies the scope of the RFC (the new attribute is for forward only).
BTW, what datatypes is oneDNN expected to support for the mask?
Currently, we test only float32. I think bfloat16 support will probably also be needed in the future.
rfcs/20230818-Dropout/README.md
Outdated
/// Sets probability for drop-out primitive attribute.
///
/// @param attr Primitive attributes.
/// @param p Drop-out probability
Is the probability needed at creation time, or can we take it at execution time?
In particular, if this parameter changes for each execution of a given primitive, passing the probability at execution time would increase the primitive cache hit rate.
In our models/benchmarks we use the same probability for all layers, but I think it is possible to make this a runtime parameter.
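For illustration, here is a rough sketch of what execution-time parameters could look like if the probability (and seed) are accepted at execution time. The DNNL_ARG_ATTR_DROPOUT_* argument names, the scalar-memory convention, and the pre-created engine/stream/primitive/memory handles are assumptions made for this sketch, not part of the proposal text.

```c
#include <stdint.h>
#include "dnnl.h"

/* Sketch: execute a convolution that was created with the dropout attribute,
 * supplying probability and seed at execution time. All handles are assumed
 * to exist already; a CPU engine is assumed so host pointers can be used
 * directly, and error handling is omitted. */
static void run_conv_with_dropout(dnnl_engine_t engine, dnnl_stream_t stream,
        dnnl_primitive_t conv_prim, dnnl_memory_t src_mem,
        dnnl_memory_t wei_mem, dnnl_memory_t dst_mem, dnnl_memory_t mask_mem) {
    /* Scalar memories holding the runtime probability (f32) and seed (s32). */
    dnnl_dims_t one = {1};
    dnnl_memory_desc_t prob_md, seed_md;
    dnnl_memory_desc_create_with_tag(&prob_md, 1, one, dnnl_f32, dnnl_a);
    dnnl_memory_desc_create_with_tag(&seed_md, 1, one, dnnl_s32, dnnl_a);

    float prob = 0.5f;    /* may change between executions of the primitive */
    int32_t seed = 12345; /* exposed so runs can be reproduced */
    dnnl_memory_t prob_mem, seed_mem;
    dnnl_memory_create(&prob_mem, prob_md, engine, &prob);
    dnnl_memory_create(&seed_mem, seed_md, engine, &seed);

    dnnl_exec_arg_t args[] = {
            {DNNL_ARG_SRC, src_mem},
            {DNNL_ARG_WEIGHTS, wei_mem},
            {DNNL_ARG_DST, dst_mem},
            {DNNL_ARG_ATTR_DROPOUT_MASK, mask_mem},        /* written output */
            {DNNL_ARG_ATTR_DROPOUT_PROBABILITY, prob_mem}, /* scalar f32 */
            {DNNL_ARG_ATTR_DROPOUT_SEED, seed_mem},        /* scalar s32 */
    };
    dnnl_primitive_execute(conv_prim, stream,
            (int)(sizeof(args) / sizeof(args[0])), args);
    dnnl_stream_wait(stream);
}
```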
From the PoC, it seems that in the ref path the dropout is applied after post-ops and not before. Is that intentional?
No, I'll change it. Thank you for finding it!
rfcs/20230818-Dropout/README.md
Outdated
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_dropout(
        dnnl_primitive_attr_t attr, uint8_t enable_drop,
I would suggest removing enable_drop from both APIs:
- In the getter, rely on the md returned from the API call. If it is a zero md, then dropout was not set by the user.
- In the setter, once the user has called dnnl_primitive_attr_set_dropout(attr, mask), dropout is set.
The current API is a bit confusing when the user calls dnnl_primitive_attr_set_dropout(attr, **false**, mask_desc).
@igorsafo, thanks for your suggestion!
I updated the README and PoC.
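For reference, a minimal sketch of how the updated attribute API might be used (based on the setter/getter proposed in this RFC after the change; the "zero md means dropout is not set" check follows the suggestion above, the dimensions are arbitrary, and error handling/cleanup are omitted):

```c
#include "dnnl.h"

/* Sketch of the updated attribute API: enable dropout on an attribute and
 * query it back. The f32 mask data type matches what the PoC currently tests. */
static int dropout_attr_roundtrip(void) {
    dnnl_dims_t dims = {32, 64, 28, 28};
    dnnl_memory_desc_t mask_md;
    dnnl_memory_desc_create_with_tag(&mask_md, 4, dims, dnnl_f32, dnnl_nchw);

    /* Setter: no enable flag; calling it is what turns dropout on. */
    dnnl_primitive_attr_t attr;
    dnnl_primitive_attr_create(&attr);
    dnnl_primitive_attr_set_dropout(attr, mask_md);

    /* Getter: a zero md (ndims == 0) would mean dropout was not requested. */
    const_dnnl_memory_desc_t queried_md;
    dnnl_primitive_attr_get_dropout(attr, &queried_md);
    int ndims = 0;
    dnnl_memory_desc_query(queried_md, dnnl_query_ndims_s32, &ndims);
    return ndims != 0; /* dropout enabled? */
}
```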
and runtime dropout arguments: output mask, which can be used in backward pass,
dropout probability and seed.
(random spot): We might need to introduce a standalone Dropout primitive to support frameworks like ONNX that register the operations supported by the backend. In the current proposal, Dropout functionality is limited to a few fusion patterns, and the remaining patterns will not be able to implement Dropout using oneDNN. Please double-check with the frameworks that the solution works for them as well.
+@georgen117
I guess the question here is "is there a benefit to supporting a dropout primitive in oneDNN?"
For the case of ONNX, isn't it OK if the oneDNN provider implements the dropout operation without oneDNN, but uses oneDNN when a dropout fusion occurs?
rfcs/20230818-Dropout/README.md
Outdated
uint8_t enabled_u8;
error::wrap_c_api(
        dnnl_primitive_attr_get_dropout(get(), &enabled_u8, &cdesc),
        "could not get parameters of a dropout attribute");
dnnl_memory_desc_t cloned_md = nullptr;
error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
        "could not clone a memory descriptor");
mask_desc = memory::desc(cloned_md);
enabled = enabled_u8;
Suggested change:
error::wrap_c_api(
        dnnl_primitive_attr_get_dropout(get(), &cdesc),
        "could not get parameters of a dropout attribute");
dnnl_memory_desc_t cloned_md = nullptr;
error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
        "could not clone a memory descriptor");
mask_desc = memory::desc(cloned_md);
Thanks, changed
This is a proposal to support the dropout operation in oneDNN via a primitive attribute.
Link to rendered document
Link to PoC
Performance data for DGL (with PyTorch backend) GNN training benchmarks on an Icelake server machine: