Feature/update sparse parameter #10351
Conversation
}
} else {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::is_gpu_place(in_tensor.place()));
this line seems redundant?
Removed.
Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
    ->FindVar(out_var_handle->name_);

if (*out_var_handle != *in_var_handle) {
I recommend, as a TODO, that we use a named method (e.g. IsSameNameAndVersion()) for VarHandle comparison. Normally, comparison operator overloads should be used with care.
Thanks, done.
}

int type = platform::ToNCCLDataType(in_tensor.type());
broadcast_calls.emplace_back([=] {
Normally we don't use "=", which might accidentally copy something big. Explicitly name the vars you need in the lambda?
Thanks, done.
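As a side note, here is a minimal standalone sketch (not the PR's actual code; the names are only illustrative) of why an explicit capture list is usually preferable to [=] when large objects are in scope:

#include <functional>
#include <vector>

int main() {
  std::vector<float> big_buffer(1 << 20);  // something expensive to copy
  int type = 0;

  std::vector<std::function<void()>> broadcast_calls;

  // [=] copies every variable the lambda body names, including big_buffer.
  broadcast_calls.emplace_back([=] { (void)type; (void)big_buffer.size(); });

  // Explicit captures make the cost visible: copy only the small `type` and
  // take the big buffer by reference (valid while it outlives the calls).
  broadcast_calls.emplace_back(
      [type, &big_buffer] { (void)type; (void)big_buffer.size(); });

  for (auto &call : broadcast_calls) call();
  return 0;
}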
&VariableVisitor::GetMutableTensor(out_var));
});
}
} else {
this else is very long...
the else code has been refactored.
}
});
#else
PADDLE_THROW("CUDA is not support.");
CUDA is not enabled.
Done
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
    places_.size());

// size_t cur_device_id = 0;
delete?
Done
} else {
  CreateComputationalOps(&result, *op, places_.size());
  if (!is_forwarding) {
    int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
This logic doesn't belong to this PR?
No. If this op's inputs don't include a sparse gradient, GetOpDeviceID will return -1, which means this op should be executed on all devices.
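For context, a rough standalone sketch of that dispatch idea (illustrative only; the real GetOpDeviceID has a different signature): each device records the sparse-gradient variable names placed on it, and an op is pinned to a device only when one of its inputs appears in that device's set.

#include <string>
#include <unordered_set>
#include <vector>

// Illustrative sketch: returns the device an op should run on, or -1 when the
// op has no sparse-gradient input and should be created on every device.
int GetOpDeviceIDSketch(
    const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
    const std::vector<std::string> &op_input_names) {
  for (size_t dev = 0; dev < var_name_on_devices.size(); ++dev) {
    for (const std::string &in : op_input_names) {
      if (var_name_on_devices[dev].count(in) != 0) {
        return static_cast<int>(dev);
      }
    }
  }
  return -1;  // execute on all devices
}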
// size_t cur_device_id = 0;
size_t update_sparse_gp_device_id = 0;
std::vector<std::unordered_set<std::string>> var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
Naming suggestion: sparse_var_xxx and bcast_sparse_var_xxx.
Done
class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 public:
#ifdef PADDLE_WITH_CUDA
Let's discuss tomorrow. Does TensorFlow use GOOGLE_CUDA in this many places? I feel PADDLE_WITH_CUDA is everywhere...
TensorFlow uses GOOGLE_CUDA in many places too.
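For what it's worth, a minimal sketch of the guard pattern being discussed (illustrative only, not Paddle's actual broadcast code): the CUDA-specific body is compiled only when PADDLE_WITH_CUDA is defined, so the same source still builds in CPU-only configurations.

#include <stdexcept>

void BroadcastWithNCCL() {
#ifdef PADDLE_WITH_CUDA
  // ... issue the NCCL broadcast calls here ...
#else
  // CPU-only build: fail loudly at runtime instead of at compile time.
  throw std::runtime_error("CUDA is not enabled.");
#endif
}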
}
}

VarHandle *SSAGraphBuilder::GetLatestVarHandle(SSAGraph *graph,
Is this used somewhere?
No, this function doesn't belong to this PR.
Branch force-pushed from d8ead0f to 7722baa
… feature/update_sparse_parameter
Branch force-pushed from cff313a to f9c680c
delete this line if it's not used?
Done
call();
}
}
// TODO(zcd): Maybe the unequal operator is not appropriate here.
be more specific?
I have refined the logic of broadcast_op: the places of the input and output tensors must be all on GPU or all on CPU.
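A minimal standalone sketch of that constraint (not the PR's actual code; Place is reduced to a simple enum here):

#include <stdexcept>
#include <vector>

enum class PlaceKind { kCPU, kGPU };

// Illustrative only: the input and all output tensors of broadcast must live
// on the same kind of place, all GPU or all CPU, never a mixture.
void EnforceSamePlaceKind(PlaceKind in_place,
                          const std::vector<PlaceKind> &out_places) {
  for (PlaceKind p : out_places) {
    if (p != in_place) {
      throw std::runtime_error(
          "The input and output of broadcast must be all on GPU or all on CPU.");
    }
  }
}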
// Variable may have many different var_handles, the version_ of these
// var_handles is different. So I don't take care of version_ temporarily
// when overloading equal.
Please add a TODO and rename this == operator to something like IsNameAndScopeSame()?
Done
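A small standalone sketch of that suggestion (the struct is illustrative, not the real VarHandle): a named method makes it obvious that version_ is intentionally ignored, which an overloaded == would hide.

#include <string>

struct VarHandleSketch {
  std::string name_;
  size_t scope_idx_ = 0;
  size_t version_ = 0;  // deliberately not part of the comparison

  // Named comparison (per the review suggestion) instead of operator==/!=.
  bool IsNameAndScopeSame(const VarHandleSketch &other) const {
    return name_ == other.name_ && scope_idx_ == other.scope_idx_;
  }
};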
if (op_dev_id == -1) {  // var on all device
  CreateComputationalOps(&result, *op, places_.size());
} else {
  CreateComputationalOp(&result, *op, op_dev_id);
Is it possible to have a var that is not on all devices now?
Of course, just by removing this var from the var's scope.
Branch force-pushed from 9da2bbf to 88be79d
Branch force-pushed from 88be79d to 0441c2c
Branch force-pushed from 7b1c794 to e8ebb91
Branch force-pushed from e8ebb91 to 881e063
Enabling parallel_exe to support updating sparse parameters.
This differs from #10096, which enables parallel_exe to support updating sparse parameters and assigns parameter gradients evenly to different cards for updates.