@@ -406,8 +406,10 @@ class InsertGPUAllocsPass final
406406 auto newAlloc = builder.create <mlir::memref::AllocOp>(
407407 loc, alloc.getType (), alloc.getDynamicSizes (),
408408 alloc.getSymbolOperands ());
409- builder.create <mlir::memref::CopyOp>(loc, allocResult,
410- newAlloc.getResult ());
409+ builder.create <mlir::gpu::MemcpyOp>(
410+ loc, /* asyncToken*/ static_cast <mlir::Type>(nullptr ),
411+ /* asyncDependencies*/ std::nullopt , newAlloc.getResult (),
412+ allocResult);
411413 use.set (newAlloc.getResult ());
412414 }
413415 }
@@ -456,8 +458,9 @@ class InsertGPUAllocsPass final
456458 /* symbolOperands*/ std::nullopt , hostShared);
457459 auto allocResult = gpuAlloc.getResult (0 );
458460 if (access.hostWrite && access.deviceRead ) {
459- auto copy =
460- builder.create <mlir::memref::CopyOp>(loc, op, allocResult);
461+ auto copy = builder.create <mlir::gpu::MemcpyOp>(
462+ loc, /* asyncToken*/ static_cast <mlir::Type>(nullptr ),
463+ /* asyncDependencies*/ std::nullopt , allocResult, op);
461464 filter.insert (copy);
462465 }
463466
@@ -476,7 +479,9 @@ class InsertGPUAllocsPass final
476479 op.replaceAllUsesExcept (allocResult, filter);
477480 builder.setInsertionPoint (term);
478481 if (access.hostRead && access.deviceWrite ) {
479- builder.create <mlir::memref::CopyOp>(loc, allocResult, op);
482+ builder.create <mlir::gpu::MemcpyOp>(
483+ loc, /* asyncToken*/ static_cast <mlir::Type>(nullptr ),
484+ /* asyncDependencies*/ std::nullopt , op, allocResult);
480485 }
481486 builder.create <mlir::gpu::DeallocOp>(loc, std::nullopt , allocResult);
482487 }
0 commit comments