Skip to content

Commit

Permalink
[flang][cuda] Handle pointer allocation with double descriptors (#124183
Browse files Browse the repository at this point in the history
)
  • Loading branch information
clementval authored Jan 23, 2025
1 parent 13dae34 commit 67a8857
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 14 deletions.
13 changes: 13 additions & 0 deletions flang/include/flang/Runtime/CUDA/pointer.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,26 @@ int RTDECL(CUFPointerAllocate)(Descriptor &, int64_t stream = -1,
bool hasStat = false, const Descriptor *errMsg = nullptr,
const char *sourceFile = nullptr, int sourceLine = 0);

/// Perform allocation of the descriptor with synchronization of it when
/// necessary.
int RTDECL(CUFPointerAllocateSync)(Descriptor &, int64_t stream = -1,
bool hasStat = false, const Descriptor *errMsg = nullptr,
const char *sourceFile = nullptr, int sourceLine = 0);

/// Perform allocation of the descriptor without synchronization. Assign data
/// from source.
int RTDEF(CUFPointerAllocateSource)(Descriptor &pointer,
const Descriptor &source, int64_t stream = -1, bool hasStat = false,
const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr,
int sourceLine = 0);

/// Perform allocation of the descriptor with synchronization of it when
/// necessary. Assign data from source.
int RTDEF(CUFPointerAllocateSourceSync)(Descriptor &pointer,
const Descriptor &source, int64_t stream = -1, bool hasStat = false,
const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr,
int sourceLine = 0);

} // extern "C"

} // namespace Fortran::runtime::cuda
Expand Down
20 changes: 12 additions & 8 deletions flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,18 +172,22 @@ struct CUFAllocateOpConversion
isPointer = true;

if (hasDoubleDescriptors(op)) {
if (isPointer)
TODO(loc, "pointer allocation with double descriptors");
// Allocation for module variable are done with custom runtime entry point
// so the descriptors can be synchronized.
mlir::func::FuncOp func;
if (op.getSource())
func = fir::runtime::getRuntimeFunc<mkRTKey(
CUFAllocatableAllocateSourceSync)>(loc, builder);
else
if (op.getSource()) {
func = isPointer ? fir::runtime::getRuntimeFunc<mkRTKey(
CUFPointerAllocateSourceSync)>(loc, builder)
: fir::runtime::getRuntimeFunc<mkRTKey(
CUFAllocatableAllocateSourceSync)>(loc, builder);
} else {
func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocateSync)>(
loc, builder);
isPointer
? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSync)>(
loc, builder)
: fir::runtime::getRuntimeFunc<mkRTKey(
CUFAllocatableAllocateSync)>(loc, builder);
}
return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
}

Expand Down
32 changes: 32 additions & 0 deletions flang/runtime/CUDA/pointer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "../assign-impl.h"
#include "../stat.h"
#include "../terminator.h"
#include "flang/Runtime/CUDA/descriptor.h"
#include "flang/Runtime/CUDA/memmove-function.h"
#include "flang/Runtime/pointer.h"

Expand All @@ -35,6 +36,24 @@ int RTDEF(CUFPointerAllocate)(Descriptor &desc, int64_t stream, bool hasStat,
return stat;
}

int RTDEF(CUFPointerAllocateSync)(Descriptor &desc, int64_t stream,
bool hasStat, const Descriptor *errMsg, const char *sourceFile,
int sourceLine) {
int stat{RTNAME(CUFPointerAllocate)(
desc, stream, hasStat, errMsg, sourceFile, sourceLine)};
#ifndef RT_DEVICE_COMPILATION
// Descriptor synchronization is only done when the allocation is done
// from the host.
if (stat == StatOk) {
void *deviceAddr{
RTNAME(CUFGetDeviceAddress)((void *)&desc, sourceFile, sourceLine)};
RTNAME(CUFDescriptorSync)
((Descriptor *)deviceAddr, &desc, sourceFile, sourceLine);
}
#endif
return stat;
}

int RTDEF(CUFPointerAllocateSource)(Descriptor &pointer,
const Descriptor &source, int64_t stream, bool hasStat,
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
Expand All @@ -48,6 +67,19 @@ int RTDEF(CUFPointerAllocateSource)(Descriptor &pointer,
return stat;
}

int RTDEF(CUFPointerAllocateSourceSync)(Descriptor &pointer,
const Descriptor &source, int64_t stream, bool hasStat,
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
int stat{RTNAME(CUFPointerAllocateSync)(
pointer, stream, hasStat, errMsg, sourceFile, sourceLine)};
if (stat == StatOk) {
Terminator terminator{sourceFile, sourceLine};
Fortran::runtime::DoFromSourceAssign(
pointer, source, terminator, &MemmoveHostToDevice);
}
return stat;
}

RT_EXT_API_GROUP_END

} // extern "C"
Expand Down
57 changes: 51 additions & 6 deletions flang/test/Fir/CUDA/cuda-allocate.fir
Original file line number Diff line number Diff line change
Expand Up @@ -198,16 +198,61 @@ func.func @_QPpointer_source() {
%c0_i32 = arith.constant 0 : i32
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = fir.alloca !fir.box<!fir.heap<!fir.array<?x?xf32>>> {bindc_name = "a", uniq_name = "_QFpointer_sourceEa"}
%4 = fir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFpointer_sourceEa"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
%5 = cuf.alloc !fir.box<!fir.heap<!fir.array<?x?xf32>>> {bindc_name = "a_d", data_attr = #cuf.cuda<device>, uniq_name = "_QFpointer_sourceEa_d"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
%7 = fir.declare %5 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFpointer_sourceEa_d"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
%8 = fir.load %4 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
%22 = cuf.allocate %7 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>> source(%8 : !fir.box<!fir.heap<!fir.array<?x?xf32>>>) {data_attr = #cuf.cuda<device>} -> i32
%0 = fir.alloca !fir.box<!fir.ptr<!fir.array<?x?xf32>>> {bindc_name = "a", uniq_name = "_QFpointer_sourceEa"}
%4 = fir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFpointer_sourceEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
%5 = cuf.alloc !fir.box<!fir.ptr<!fir.array<?x?xf32>>> {bindc_name = "a_d", data_attr = #cuf.cuda<device>, uniq_name = "_QFpointer_sourceEa_d"} -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
%7 = fir.declare %5 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFpointer_sourceEa_d"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
%8 = fir.load %4 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
%22 = cuf.allocate %7 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>> source(%8 : !fir.box<!fir.ptr<!fir.array<?x?xf32>>>) {data_attr = #cuf.cuda<device>} -> i32
return
}

// CHECK-LABEL: func.func @_QPpointer_source()
// CHECK: _FortranACUFPointerAllocateSource

fir.global @_QMdataEb2 {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?xi32>>> {
%c0 = arith.constant 0 : index
%0 = fir.zero_bits !fir.ptr<!fir.array<?xi32>>
%1 = fir.shape %c0 : (index) -> !fir.shape<1>
%2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.ptr<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xi32>>>
fir.has_value %2 : !fir.box<!fir.ptr<!fir.array<?xi32>>>
}

func.func @_QQpointer_sync() attributes {fir.bindc_name = "test"} {
%c0_i32 = arith.constant 0 : i32
%c10_i32 = arith.constant 10 : i32
%c1 = arith.constant 1 : index
%0 = fir.address_of(@_QMdataEb2) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
%1 = fir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMdataEb"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>)
%2 = fir.convert %1 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
%3 = fir.convert %c1 : (index) -> i64
%4 = fir.convert %c10_i32 : (i32) -> i64
fir.call @_FortranAAllocatableSetBounds(%2, %c0_i32, %3, %4) fastmath<contract> : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> ()
%6 = cuf.allocate %1 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>} -> i32
return
}

// CHECK-LABEL: func.func @_QQpointer_sync()
// CHECK: _FortranACUFPointerAllocateSync

fir.global @_QMmod1Ea_d2 {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?x?xf32>>> {
%c0 = arith.constant 0 : index
%0 = fir.zero_bits !fir.ptr<!fir.array<?x?xf32>>
%1 = fir.shape %c0, %c0 : (index, index) -> !fir.shape<2>
%2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.ptr<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.box<!fir.ptr<!fir.array<?x?xf32>>>
fir.has_value %2 : !fir.box<!fir.ptr<!fir.array<?x?xf32>>>
}
func.func @_QMmod1Ppointer_source_global() {
%0 = fir.address_of(@_QMmod1Ea_d2) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
%1 = fir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMmod1Ea_d"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
%2 = fir.alloca !fir.box<!fir.ptr<!fir.array<?x?xf32>>> {bindc_name = "a", uniq_name = "_QMmod1Fallocate_source_globalEa"}
%6 = fir.declare %2 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod1Fallocate_source_globalEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
%7 = fir.load %6 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
%21 = cuf.allocate %1 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>> source(%7 : !fir.box<!fir.ptr<!fir.array<?x?xf32>>>) {data_attr = #cuf.cuda<device>} -> i32
return
}

// CHECK-LABEL: func.func @_QMmod1Ppointer_source_global()
// CHECK: fir.call @_FortranACUFPointerAllocateSourceSync

} // end of module

0 comments on commit 67a8857

Please sign in to comment.