@@ -214,56 +214,34 @@ void TensorAddImpl(const framework::Tensor& src, framework::Tensor* dst,
214
214
func (dev_ctx, src, dst);
215
215
}
216
216
217
- template <typename VarType>
218
- std::shared_ptr<pten::DenseTensor> GetInnerDstTensor (VarType* dst) {
219
- PADDLE_THROW (platform::errors::Unavailable (
220
- " GetInnerDstTensor only support egr::EagerTensor or framework::Variable, "
221
- " please check your code and make sure you are using one of them." ));
222
- return nullptr ;
223
- }
224
-
225
- template <typename VarType>
226
- std::shared_ptr<pten::DenseTensor> GetInnerSrcTensor (const VarType& src) {
227
- PADDLE_THROW (platform::errors::Unavailable (
228
- " GetInnerSrcTensor only support egr::EagerTensor or framework::Variable, "
229
- " please check your code and make sure you are using one of them." ));
230
- return nullptr ;
231
- }
232
-
233
- template <>
234
- std::shared_ptr<pten::DenseTensor> GetInnerDstTensor<egr::EagerTensor>(
235
- egr::EagerTensor* dst) {
217
+ std::shared_ptr<pten::DenseTensor> GetInnerDstTensor (egr::EagerTensor* dst) {
236
218
std::shared_ptr<pten::DenseTensor> dst_tensor =
237
219
std::dynamic_pointer_cast<pten::DenseTensor>(dst->impl ());
238
220
return dst_tensor;
239
221
}
240
222
241
- template <>
242
- std::shared_ptr<pten::DenseTensor> GetInnerSrcTensor<egr::EagerTensor>(
223
+ std::shared_ptr<pten::DenseTensor> GetInnerSrcTensor (
243
224
const egr::EagerTensor& src) {
244
225
std::shared_ptr<pten::DenseTensor> dst_tensor =
245
226
std::dynamic_pointer_cast<pten::DenseTensor>(src.impl ());
246
227
return dst_tensor;
247
228
}
248
229
249
- template <>
250
- std::shared_ptr<pten::DenseTensor> GetInnerDstTensor<framework::Variable>(
251
- framework::Variable* dst) {
230
+ std::shared_ptr<pten::DenseTensor> GetInnerDstTensor (framework::Variable* dst) {
252
231
auto * dst_tensor = dst->GetMutable <framework::LoDTensor>();
253
232
return std::make_shared<pten::DenseTensor>(*dst_tensor);
254
233
}
255
234
256
- template <>
257
- std::shared_ptr<pten::DenseTensor> GetInnerSrcTensor<framework::Variable>(
235
+ std::shared_ptr<pten::DenseTensor> GetInnerSrcTensor (
258
236
const framework::Variable& src) {
259
237
auto & src_tensor = src.Get <framework::LoDTensor>();
260
238
return std::make_shared<pten::DenseTensor>(src_tensor);
261
239
}
262
240
263
241
template <typename VarType>
264
242
void TensorAdd (const VarType& src, VarType* dst) {
265
- std::shared_ptr<pten::DenseTensor> d_tensor = GetInnerDstTensor<VarType> (dst);
266
- std::shared_ptr<pten::DenseTensor> s_tensor = GetInnerSrcTensor<VarType> (src);
243
+ std::shared_ptr<pten::DenseTensor> d_tensor = GetInnerDstTensor (dst);
244
+ std::shared_ptr<pten::DenseTensor> s_tensor = GetInnerSrcTensor (src);
267
245
268
246
auto * dst_tensor = d_tensor.get ();
269
247
auto & src_tensor = *s_tensor.get ();
0 commit comments