Skip to content

Commit c0fced7

Browse files
authored
[XPU] do not use null stream when CDNN_CLUSTER_PARALLEL is ON (#67276)
1 parent d688d3f commit c0fced7

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

paddle/phi/core/memory/allocation/allocator_facade.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1893,6 +1893,12 @@ const std::shared_ptr<Allocator>& AllocatorFacade::GetAllocator(
18931893
}
18941894
return m->GetAllocator(place, /* A non-zero num to choose allocator_ */ 1);
18951895
}
1896+
void AllocatorFacade::SetDefaultStream(const phi::XPUPlace& place,
1897+
XPUStream stream) {
1898+
if (m_->IsStreamSafeCUDAAllocatorUsed()) {
1899+
m_->SetDefaultStream(place, stream);
1900+
}
1901+
}
18961902
#endif
18971903

18981904
#ifdef PADDLE_WITH_CUSTOM_DEVICE

paddle/phi/core/memory/allocation/allocator_facade.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class AllocatorFacade {
9797
#elif defined(PADDLE_WITH_XPU)
9898
TEST_API const std::shared_ptr<Allocator>& GetAllocator(
9999
const phi::Place& place, XPUStream stream);
100+
void SetDefaultStream(const phi::XPUPlace& place, XPUStream stream);
100101
#endif
101102

102103
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

paddle/phi/core/platform/device_context.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ limitations under the License. */
3333
#include "paddle/phi/core/platform/cuda_device_guard.h"
3434
#endif
3535

36+
#if defined(PADDLE_WITH_XPU)
37+
#include "paddle/phi/backends/xpu/xpu_context.h"
38+
#endif
39+
3640
namespace paddle {
3741
namespace platform {
3842

@@ -105,7 +109,12 @@ inline std::unique_ptr<DeviceContext> CreateDeviceContext(
105109
#endif
106110
} else if (p.GetType() == phi::AllocationType::XPU) {
107111
#if defined(PADDLE_WITH_XPU)
108-
dev_ctx->SetAllocator(instance.GetAllocator(p).get());
112+
auto* xpu_ctx = dynamic_cast<phi::XPUContext*>(dev_ctx);
113+
if (!disable_setting_default_stream_for_allocator) {
114+
instance.SetDefaultStream(phi::XPUPlace(p.GetDeviceId()),
115+
xpu_ctx->stream());
116+
}
117+
dev_ctx->SetAllocator(instance.GetAllocator(p, xpu_ctx->stream()).get());
109118
dev_ctx->SetGenerator(phi::DefaultXPUGenerator(p.GetDeviceId()).get());
110119
#endif
111120
#ifdef PADDLE_WITH_CUSTOM_DEVICE

0 commit comments

Comments
 (0)