@@ -56,6 +56,78 @@ static void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad) {
5656 }
5757}
5858
59+ void IncreaseVarbaseReferenceCountUntilCopyComplete (
60+ const std::shared_ptr<imperative::VarBase>& var,
61+ const platform::Place& place) {
62+ // Note(zhiqiu): Follow the logic of TensorCopy to determine the place that we
63+ // need to add callback, see tensor_utils.cc:245
64+ auto place_ = platform::is_gpu_place (place) ? place : var->Place ();
65+
66+ auto tracer = imperative::GetCurrentTracer ();
67+ auto gc = tracer->MutableGarbageCollectorIfNotExists (place_);
68+
69+ // Note(zhiqiu): This is an empty callback, the only way is to "reference"
70+ // var, so it will not be destructed until the kernels launched at current
71+ // stream of given place is finished.
72+ auto callback = [var, place_]() {
73+ VLOG (4 ) << " Run callback of var:" << var->Name () << " at place " << place_;
74+ };
75+
76+ gc->DirectClearCallback (callback);
77+ }
78+
79+ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists (
80+ const platform::Place& place) {
81+ // if not exists, create a new GarbageCollector at given place
82+ if (gcs_.count (place) == 0 ) {
83+ std::unique_ptr<framework::GarbageCollector> gc;
84+ if (platform::is_gpu_place (place)) {
85+ #ifdef PADDLE_WITH_CUDA
86+ gc.reset (new framework::DefaultStreamGarbageCollector (
87+ BOOST_GET_CONST (platform::CUDAPlace, place), 0 ));
88+
89+ VLOG (10 ) << " Created GarbageCollector at " << place;
90+ #else
91+ PADDLE_THROW (platform::errors::PermissionDenied (
92+ " Paddle can't use CUDA device since it's not compiled with CUDA,"
93+ " Please recompile or reinstall Paddle with GPU support." ));
94+ #endif
95+ } else if (platform::is_cuda_pinned_place (place)) {
96+ #ifdef PADDLE_WITH_CUDA
97+ gc.reset (new framework::CUDAPinnedGarbageCollector (
98+ BOOST_GET_CONST (platform::CUDAPinnedPlace, place), 0 ));
99+
100+ VLOG (10 ) << " Created GarbageCollector at " << place;
101+ #else
102+ PADDLE_THROW (platform::errors::PermissionDenied (
103+ " Paddle can't use CUDAPinned device since it's not compiled with "
104+ " CUDA,"
105+ " Please recompile or reinstall Paddle with GPU support." ));
106+ #endif
107+ } else if (platform::is_xpu_place (place)) {
108+ #if defined(PADDLE_WITH_XPU)
109+ gc.reset (new framework::XPUGarbageCollector (
110+ BOOST_GET_CONST (platform::XPUPlace, place), 0 ));
111+ VLOG (10 ) << " Created GarbageCollector at " << place;
112+ #else
113+ PADDLE_THROW (platform::errors::PermissionDenied (
114+ " Paddle can't use XPU device since it's not compiled with XPU,"
115+ " Please recompile or reinstall Paddle with XPU support." ));
116+ #endif
117+ } else if (platform::is_cpu_place (place)) {
118+ gc.reset (new framework::CPUGarbageCollector (
119+ BOOST_GET_CONST (platform::CPUPlace, place), 0 ));
120+ VLOG (10 ) << " Created GarbageCollector at " << place;
121+ } else {
122+ PADDLE_THROW (platform::errors::PreconditionNotMet (
123+ " Unsupported place for garbage collection" ));
124+ }
125+ gcs_.emplace (place, std::move (gc));
126+ }
127+
128+ return gcs_.at (place).get ();
129+ }
130+
59131void Tracer::TraceOp (const std::string& type, const NameVarBaseMap& ins,
60132 const NameVarBaseMap& outs, framework::AttributeMap attrs,
61133 const platform::Place& place, bool trace_backward) {
0 commit comments