Skip to content
6 changes: 4 additions & 2 deletions backends/metax_gpu/patch/tmp/mixed_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,8 @@ class MixVector {

// the unify method to access CPU or CUDA data. immutable.
const T *Data(phi::Place place) const {
if (place.GetType() == phi::AllocationType::GPU) {
if (place.GetType() == phi::AllocationType::GPU ||
place.GetType() == phi::AllocationType::CUSTOM) {
return CUDAData(place);
} else {
return data();
Expand All @@ -395,7 +396,8 @@ class MixVector {

// the unify method to access CPU or CUDA data. mutable.
T *MutableData(phi::Place place) {
if (place.GetType() == phi::AllocationType::GPU) {
if (place.GetType() == phi::AllocationType::GPU ||
place.GetType() == phi::AllocationType::CUSTOM) {
return CUDAMutableData(place);
} else {
return data();
Expand Down