Skip to content

Commit cbc99da

Browse files
committed
Use overload instead of template
1 parent c444e83 commit cbc99da

File tree

1 file changed

+6
-28
lines changed

1 file changed

+6
-28
lines changed

paddle/fluid/imperative/gradient_accumulator.cc

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -214,56 +214,34 @@ void TensorAddImpl(const framework::Tensor& src, framework::Tensor* dst,
214214
func(dev_ctx, src, dst);
215215
}
216216

217-
template <typename VarType>
218-
std::shared_ptr<pten::DenseTensor> GetInnerDstTensor(VarType* dst) {
219-
PADDLE_THROW(platform::errors::Unavailable(
220-
"GetInnerDstTensor only support egr::EagerTensor or framework::Variable, "
221-
"please check your code and make sure you are using one of them."));
222-
return nullptr;
223-
}
224-
225-
template <typename VarType>
226-
std::shared_ptr<pten::DenseTensor> GetInnerSrcTensor(const VarType& src) {
227-
PADDLE_THROW(platform::errors::Unavailable(
228-
"GetInnerSrcTensor only support egr::EagerTensor or framework::Variable, "
229-
"please check your code and make sure you are using one of them."));
230-
return nullptr;
231-
}
232-
233-
template <>
234-
std::shared_ptr<pten::DenseTensor> GetInnerDstTensor<egr::EagerTensor>(
235-
egr::EagerTensor* dst) {
217+
std::shared_ptr<pten::DenseTensor> GetInnerDstTensor(egr::EagerTensor* dst) {
236218
std::shared_ptr<pten::DenseTensor> dst_tensor =
237219
std::dynamic_pointer_cast<pten::DenseTensor>(dst->impl());
238220
return dst_tensor;
239221
}
240222

241-
template <>
242-
std::shared_ptr<pten::DenseTensor> GetInnerSrcTensor<egr::EagerTensor>(
223+
std::shared_ptr<pten::DenseTensor> GetInnerSrcTensor(
243224
const egr::EagerTensor& src) {
244225
std::shared_ptr<pten::DenseTensor> dst_tensor =
245226
std::dynamic_pointer_cast<pten::DenseTensor>(src.impl());
246227
return dst_tensor;
247228
}
248229

249-
template <>
250-
std::shared_ptr<pten::DenseTensor> GetInnerDstTensor<framework::Variable>(
251-
framework::Variable* dst) {
230+
std::shared_ptr<pten::DenseTensor> GetInnerDstTensor(framework::Variable* dst) {
252231
auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
253232
return std::make_shared<pten::DenseTensor>(*dst_tensor);
254233
}
255234

256-
template <>
257-
std::shared_ptr<pten::DenseTensor> GetInnerSrcTensor<framework::Variable>(
235+
std::shared_ptr<pten::DenseTensor> GetInnerSrcTensor(
258236
const framework::Variable& src) {
259237
auto& src_tensor = src.Get<framework::LoDTensor>();
260238
return std::make_shared<pten::DenseTensor>(src_tensor);
261239
}
262240

263241
template <typename VarType>
264242
void TensorAdd(const VarType& src, VarType* dst) {
265-
std::shared_ptr<pten::DenseTensor> d_tensor = GetInnerDstTensor<VarType>(dst);
266-
std::shared_ptr<pten::DenseTensor> s_tensor = GetInnerSrcTensor<VarType>(src);
243+
std::shared_ptr<pten::DenseTensor> d_tensor = GetInnerDstTensor(dst);
244+
std::shared_ptr<pten::DenseTensor> s_tensor = GetInnerSrcTensor(src);
267245

268246
auto* dst_tensor = d_tensor.get();
269247
auto& src_tensor = *s_tensor.get();

0 commit comments

Comments
 (0)