Conversation

chengduoZH (Contributor) commented May 2, 2018:

This PR enables parallel_exe to support updating sparse parameters.
It differs from #10096, which enables parallel_exe to support updating sparse parameters and assigns parameter gradients evenly to different cards for updates.

chengduoZH requested review from panyx0718 and reyoung on May 3, 2018 01:48.
}
} else {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::is_gpu_place(in_tensor.place()));
Reviewer (Contributor): this line seems redundant?

chengduoZH (Author): Removed.

Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
->FindVar(out_var_handle->name_);

if (*out_var_handle != *in_var_handle) {
Reviewer (Contributor): I recommend, as a TODO, using a named method (e.g. IsSameNameAndVersion()) for VarHandle comparison. Normally, comparison operator overloads should be used with care.

chengduoZH (Author): Thanks, done.
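
For reference, a minimal sketch of what such a named comparison could look like (the field names name_, version_, and scope_idx_ appear in this PR's diff; the rest of the class is an assumption):

    #include <cstddef>
    #include <string>

    // Hypothetical sketch, not the PR's actual class: only the fields
    // relevant to the comparison are shown.
    struct VarHandle {
      std::string name_;
      size_t version_;
      size_t scope_idx_;

      // A named method is clearer at call sites than an overloaded
      // operator==, and makes explicit which fields are compared.
      bool IsSameNameAndVersion(const VarHandle &o) const {
        return name_ == o.name_ && version_ == o.version_;
      }
    };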

}

int type = platform::ToNCCLDataType(in_tensor.type());
broadcast_calls.emplace_back([=] {
Reviewer (Contributor): Normally we don't use "=", which might accidentally copy some big stuff. Explicitly name the vars you need in the lambda?

chengduoZH (Author): Thanks, done.
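
To illustrate the capture concern with plain C++ (the names below are hypothetical, not from the PR):

    #include <functional>
    #include <vector>

    void CaptureExample() {
      std::vector<float> big_tensor(1 << 20);  // expensive to copy
      int type = 0;

      std::vector<std::function<void()>> calls;

      // [=] silently copies every variable the body uses, including
      // big_tensor; an explicit capture list states exactly what is
      // taken and how.
      calls.emplace_back([type, &big_tensor] {
        (void)type;               // by value: cheap
        (void)big_tensor.size();  // by reference: no copy
      });

      for (auto &call : calls) call();
    }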

&VariableVisitor::GetMutableTensor(out_var));
});
}
} else {
Reviewer (Contributor): this else is very long...

chengduoZH (Author): The else code has been refactored.

}
});
#else
PADDLE_THROW("CUDA is not support.");
Reviewer (Contributor): "CUDA is not enabled."

chengduoZH (Author): Done.

std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size());

// size_t cur_device_id = 0;
Reviewer (Contributor): delete?

chengduoZH (Author): Done.

} else {
CreateComputationalOps(&result, *op, places_.size());
if (!is_forwarding) {
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
Reviewer (Contributor): This logic doesn't belong to this PR?

chengduoZH (Author): No. If this op's inputs don't include a sparse gradient, GetOpDeviceID will return -1, which means the op should be executed on all devices.
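
A rough sketch of that dispatch rule (the map and function names below are hypothetical; the real GetOpDeviceID in this PR may differ):

    #include <string>
    #include <unordered_map>
    #include <vector>

    // Returns the device that owns one of the op's sparse-gradient
    // inputs, or -1 if none of the inputs is a sparse gradient, in
    // which case the op is created on every device.
    int GetOpDeviceIDSketch(
        const std::unordered_map<std::string, int> &sparse_var_device,
        const std::vector<std::string> &op_inputs) {
      for (const auto &name : op_inputs) {
        auto it = sparse_var_device.find(name);
        if (it != sparse_var_device.end()) {
          return it->second;  // pin the op to the owning device
        }
      }
      return -1;  // no sparse-gradient input: run on all devices
    }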

// size_t cur_device_id = 0;
size_t update_sparse_gp_device_id = 0;
std::vector<std::unordered_set<std::string>> var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
Reviewer (Contributor): rename these to sparse_var_xxx and bcast_sparse_var_xxx?

chengduoZH (Author): Done.


class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
public:
#ifdef PADDLE_WITH_CUDA
Reviewer (Contributor): Let's discuss tomorrow. Does tensorflow use this many GOOGLE_CUDA? I feel PADDLE_WITH_CUDA is everywhere...

chengduoZH (Author): tensorflow uses GOOGLE_CUDA in many places too.
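
For context, the guard pattern being discussed, as it appears in this diff (CUDA-only code compiled under PADDLE_WITH_CUDA, with a hard failure on CPU-only builds):

    #ifdef PADDLE_WITH_CUDA
      // ... NCCL / GPU-specific implementation ...
    #else
      PADDLE_THROW("CUDA is not enabled.");
    #endif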

}
}

VarHandle *SSAGraphBuilder::GetLatestVarHandle(SSAGraph *graph,
Reviewer (Contributor): Is this used somewhere?

chengduoZH (Author): No, this function doesn't belong to this PR.

chengduoZH force-pushed the feature/update_sparse_parameter branch from d8ead0f to 7722baa on May 4, 2018 08:08.
chengduoZH force-pushed the feature/update_sparse_parameter branch from cff313a to f9c680c on May 4, 2018 08:28.
panyx0718 previously approved these changes on May 4, 2018.
Reviewer (Contributor): delete this line if it's not used?

chengduoZH (Author): Done.

call();
}
}
// TODO(zcd): Maybe the unequal operator is not appropriate here.
Reviewer (Contributor): be more specific?

chengduoZH (Author), May 5, 2018: I have refined the logic of broadcast_op; the Places of the input and output tensors must be all on GPU or all on CPU.
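
A minimal sketch of that invariant, reusing PADDLE_ENFORCE and platform::is_gpu_place from this diff (the surrounding function and the out_tensor name are assumptions):

    // Inside the broadcast op's run logic: input and output tensors
    // must live on the same kind of device, all GPU or all CPU.
    bool in_gpu = platform::is_gpu_place(in_tensor.place());
    bool out_gpu = platform::is_gpu_place(out_tensor.place());
    PADDLE_ENFORCE(in_gpu == out_gpu,
                   "The input and output tensors must be all on GPU "
                   "or all on CPU.");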

// Variable may have many different var_handles, and the version_ of these
// var_handles is different. So I don't take care of version_ temporarily
// when overloading equal.
Reviewer (Contributor): Please add a TODO and rename this == operator to something like IsNameAndScopeSame()?

chengduoZH (Author): Done.

if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size());
} else {
CreateComputationalOp(&result, *op, op_dev_id);
Reviewer (Contributor): Is it possible to have var not being on all devices now?

chengduoZH (Author): Of course, just by removing this var from the var's scope.

chengduoZH force-pushed the feature/update_sparse_parameter branch 2 times, most recently from 9da2bbf to 88be79d on May 5, 2018 05:27.
chengduoZH force-pushed the feature/update_sparse_parameter branch from 88be79d to 0441c2c on May 5, 2018 05:54.
chengduoZH force-pushed the feature/update_sparse_parameter branch from 7b1c794 to e8ebb91 on May 5, 2018 07:44.
chengduoZH force-pushed the feature/update_sparse_parameter branch from e8ebb91 to 881e063 on May 5, 2018 07:51.
chengduoZH merged commit 99acf1d into PaddlePaddle:develop on May 7, 2018.