File tree Expand file tree Collapse file tree 3 files changed +17
-1
lines changed Expand file tree Collapse file tree 3 files changed +17
-1
lines changed Original file line number Diff line number Diff line change @@ -1893,6 +1893,12 @@ const std::shared_ptr<Allocator>& AllocatorFacade::GetAllocator(
1893
1893
}
1894
1894
return m->GetAllocator (place, /* A non-zero num to choose allocator_ */ 1 );
1895
1895
}
1896
+ void AllocatorFacade::SetDefaultStream (const phi::XPUPlace& place,
1897
+ XPUStream stream) {
1898
+ if (m_->IsStreamSafeCUDAAllocatorUsed ()) {
1899
+ m_->SetDefaultStream (place, stream);
1900
+ }
1901
+ }
1896
1902
#endif
1897
1903
1898
1904
#ifdef PADDLE_WITH_CUSTOM_DEVICE
Original file line number Diff line number Diff line change @@ -97,6 +97,7 @@ class AllocatorFacade {
97
97
#elif defined(PADDLE_WITH_XPU)
98
98
TEST_API const std::shared_ptr<Allocator>& GetAllocator (
99
99
const phi::Place& place, XPUStream stream);
100
+ void SetDefaultStream (const phi::XPUPlace& place, XPUStream stream);
100
101
#endif
101
102
102
103
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Original file line number Diff line number Diff line change @@ -33,6 +33,10 @@ limitations under the License. */
33
33
#include " paddle/phi/core/platform/cuda_device_guard.h"
34
34
#endif
35
35
36
+ #if defined(PADDLE_WITH_XPU)
37
+ #include " paddle/phi/backends/xpu/xpu_context.h"
38
+ #endif
39
+
36
40
namespace paddle {
37
41
namespace platform {
38
42
@@ -105,7 +109,12 @@ inline std::unique_ptr<DeviceContext> CreateDeviceContext(
105
109
#endif
106
110
} else if (p.GetType () == phi::AllocationType::XPU) {
107
111
#if defined(PADDLE_WITH_XPU)
108
- dev_ctx->SetAllocator (instance.GetAllocator (p).get ());
112
+ auto * xpu_ctx = dynamic_cast <phi::XPUContext*>(dev_ctx);
113
+ if (!disable_setting_default_stream_for_allocator) {
114
+ instance.SetDefaultStream (phi::XPUPlace (p.GetDeviceId ()),
115
+ xpu_ctx->stream ());
116
+ }
117
+ dev_ctx->SetAllocator (instance.GetAllocator (p, xpu_ctx->stream ()).get ());
109
118
dev_ctx->SetGenerator (phi::DefaultXPUGenerator (p.GetDeviceId ()).get ());
110
119
#endif
111
120
#ifdef PADDLE_WITH_CUSTOM_DEVICE
You can’t perform that action at this time.
0 commit comments