Skip to content

Commit df82fd3

Browse files
authored
[BugFix]Fix OneDNN Kernels Bug when use pass (#48364)
* Fix onednn kernel bugs * fix gpu bugs
1 parent b4b926f commit df82fd3

File tree

5 files changed

+34
-0
lines changed

5 files changed

+34
-0
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3233,6 +3233,29 @@ void OperatorWithKernel::BuildPhiKernelContext(
32333233
}
32343234
VLOG(4) << "Done attributes";
32353235

3236+
// Clear All old attrs before add new attrs,
3237+
// because sometimes old attrs may be misused.
3238+
#if defined(PADDLE_WITH_MKLDNN)
3239+
if (phi::OneDNNContext::classof(dev_ctx)) {
3240+
phi::OneDNNContext* one_dnn_ctx = static_cast<phi::OneDNNContext*>(dev_ctx);
3241+
one_dnn_ctx->ClearDnnAttr();
3242+
}
3243+
#endif
3244+
3245+
// Note(YuanRisheng): Now, we can't open code below.
3246+
// Because some unittest run OLD dygraph and ExtraAttr is not supported in OLD
3247+
// dygraph. So, here we use trick that dev_ctx is a global object. We can
3248+
// store ExtraAttr in static graph and when unittest run OLD dygraph, it can
3249+
// obtain these ExtraAttr. We can open this code when OLD dygraph is no longer
3250+
// used.
3251+
/*
3252+
#if defined(PADDLE_WITH_CUDA)
3253+
if(phi::GPUContext::classof(dev_ctx)) {
3254+
phi::GPUContext* gpu_dnn_ctx = static_cast<phi::GPUContext*>(dev_ctx);
3255+
gpu_dnn_ctx->ClearDnnAttr();
3256+
}
3257+
#endif
3258+
*/
32363259
// For compatible with Op with extra attrs for specific backend
32373260
#if defined(PADDLE_WITH_MKLDNN) || defined(PADDLE_WITH_CUDA)
32383261
auto& runtime_attrs = RuntimeAttrs();

paddle/phi/backends/gpu/gpu_context.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,8 @@ struct GPUContext::Impl {
740740
dnn_attrs_[attr_name] = attr;
741741
}
742742

743+
void ClearDnnAttr() { dnn_attrs_.clear(); }
744+
743745
// use one flag for all handles?
744746
// they should be accessed consistently
745747
bool owned_{false};
@@ -1042,4 +1044,6 @@ void GPUContext::SetDnnAttr(const std::string& attr_name, Attribute attr) {
10421044
return impl_->SetDnnAttr(attr_name, std::move(attr));
10431045
}
10441046

1047+
void GPUContext::ClearDnnAttr() { return impl_->ClearDnnAttr(); }
1048+
10451049
} // namespace phi

paddle/phi/backends/gpu/gpu_context.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ class PADDLE_API GPUContext : public DeviceContext,
172172
bool HasDnnAttr(const std::string& attr_name) const;
173173
const Attribute& GetDnnAttr(const std::string& attr_name) const;
174174
void SetDnnAttr(const std::string& attr_name, Attribute attr);
175+
void ClearDnnAttr();
175176

176177
static const char* name() { return "GPUContext"; }
177178

paddle/phi/backends/onednn/onednn_context.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ struct OneDNNContext::Impl {
301301
dnn_attrs_[attr_name] = attr;
302302
}
303303

304+
void ClearDnnAttr() { dnn_attrs_.clear(); }
305+
304306
bool HasDnnInput(const std::string& input_name) const {
305307
return dnn_inputs_.count(input_name) != 0UL;
306308
}
@@ -425,6 +427,8 @@ void OneDNNContext::SetDnnAttr(const std::string& attr_name, Attribute attr) {
425427
return impl_->SetDnnAttr(attr_name, std::move(attr));
426428
}
427429

430+
void OneDNNContext::ClearDnnAttr() { return impl_->ClearDnnAttr(); }
431+
428432
bool OneDNNContext::HasDnnInput(const std::string& input_name) const {
429433
return impl_->HasDnnInput(input_name);
430434
}

paddle/phi/backends/onednn/onednn_context.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ class OneDNNContext : public CPUContext {
146146
const DenseTensor* GetDnnInput(const std::string& input_name) const;
147147
void SetDnnInput(const std::string& input_name, const DenseTensor* input);
148148

149+
void ClearDnnAttr();
150+
149151
void SetInputsName(const TensorNameMap& inputs_name);
150152

151153
void SetOutputsName(const TensorNameMap& outputs_name);

0 commit comments

Comments
 (0)