Skip to content

[Eager] Refactor TensorAdd by template #39282

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/fluid/eager/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
set(eager_deps pten pten_api hook_utils tensor_utils utils global_utils backward pten_tensor legacy autograd_meta grad_node_info grad_tensor_holder gradient_accumulation accumulation_node)
set(eager_deps pten pten_api hook_utils tensor_utils utils global_utils backward pten_tensor legacy autograd_meta grad_node_info grad_tensor_holder accumulation_node)
set(fluid_deps tracer layer proto_desc operator op_registry variable_helper memcpy)
set(generated_deps dygraph_function dygraph_node)

Expand All @@ -12,7 +12,7 @@ add_subdirectory(accumulation)
add_subdirectory(legacy)

cc_library(grad_node_info SRCS grad_node_info.cc DEPS pten pten_api)
cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulation)
cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator)

cc_library(autograd_meta SRCS autograd_meta.cc DEPS pten pten_api)
cc_library(utils SRCS utils.cc DEPS pten pten_api global_utils layer proto_desc operator op_registry variable_helper memcpy scale_op autograd_meta hook_utils)
Expand Down
3 changes: 1 addition & 2 deletions paddle/fluid/eager/accumulation/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
cc_library(gradient_accumulation SRCS gradient_accumulation.cc DEPS blas pten pten_api var_type_traits layer math_function)
cc_library(accumulation_node SRCS accumulation_node.cc DEPS gradient_accumulation pten pten_api grad_node_info)
cc_library(accumulation_node SRCS accumulation_node.cc DEPS gradient_accumulator pten pten_api grad_node_info)
4 changes: 2 additions & 2 deletions paddle/fluid/eager/accumulation/accumulation_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
// limitations under the License.

#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/accumulation/gradient_accumulation.h"
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"

#include "paddle/pten/api/all.h"
#include "paddle/pten/core/dense_tensor.h"
Expand All @@ -35,7 +35,7 @@ static void CopyOrAddTensor(egr::EagerTensor* tensor,
*tensor = t;
} else {
// Accumulation
egr::TensorAdd(t, tensor);
paddle::imperative::TensorAdd<egr::EagerTensor>(t, tensor);
}
}

Expand Down
291 changes: 0 additions & 291 deletions paddle/fluid/eager/accumulation/gradient_accumulation.cc

This file was deleted.

23 changes: 0 additions & 23 deletions paddle/fluid/eager/accumulation/gradient_accumulation.h

This file was deleted.

10 changes: 5 additions & 5 deletions paddle/fluid/eager/grad_tensor_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

#include "paddle/fluid/eager/grad_tensor_holder.h"
#include "paddle/fluid/eager/accumulation/gradient_accumulation.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"

#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/math/math_function.h"
Expand Down Expand Up @@ -72,17 +72,17 @@ void GradTensorHolder::add(size_t slot_id, size_t rank,
} else {
// Accumulation
if (t.initialized() && buffer_tensor.initialized()) {
TensorAdd(t, &buffer_tensor);
paddle::imperative::TensorAdd<egr::EagerTensor>(t, &buffer_tensor);
} else if (t.Var().IsInitialized() &&
buffer_tensor.Var().IsInitialized()) {
VariableAdd(t, &buffer_tensor);
paddle::imperative::VariableAdd(t, &buffer_tensor);
} else if (t.Var().IsInitialized() && buffer_tensor.initialized()) {
// TODO(jiabin): This can be merge to upper if case.
buffer_tensor.SyncToVar();
VariableAdd(t, &buffer_tensor);
paddle::imperative::VariableAdd(t, &buffer_tensor);
} else if (t.initialized() && buffer_tensor.Var().IsInitialized()) {
buffer_tensor.SyncToTensor();
TensorAdd(t, &buffer_tensor);
paddle::imperative::TensorAdd<egr::EagerTensor>(t, &buffer_tensor);
} else {
// Should not happend case
// 1. both not init
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/imperative/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
cc_library(imperative_flag SRCS flags.cc DEPS gflags flags)
IF(WITH_XPU)
cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils)
cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils pten_api)
ELSE()
cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils)
cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils pten_api)
ENDIF()
cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry)
add_subdirectory(jit)
Expand Down
Loading