Skip to content

Conversation

clementval
Copy link
Contributor

When a field in a derived type is c_devptr, keep check if we can do a memcpy instead of falling back to the runtime assignment.

Many internal CUDA Fortran derived type have a c_devptr field and this would lead to stack overflow on the device if the assignment is performed by the runtime function.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Jan 24, 2025
@llvmbot
Copy link
Member

llvmbot commented Jan 24, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

When a field in a derived type is c_devptr, keep check if we can do a memcpy instead of falling back to the runtime assignment.

Many internal CUDA Fortran derived type have a c_devptr field and this would lead to stack overflow on the device if the assignment is performed by the runtime function.


Full diff: https://github.com/llvm/llvm-project/pull/124322.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+2-1)
  • (modified) flang/test/Lower/CUDA/cuda-devptr.cuf (+19)
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 64c540cfb95ae6..35dc9a2abd69c4 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -1410,7 +1410,8 @@ static bool recordTypeCanBeMemCopied(fir::RecordType recordType) {
   for (auto [_, fieldType] : recordType.getTypeList()) {
     // Derived type component may have user assignment (so far, we cannot tell
     // in FIR, so assume it is always the case, TODO: get the actual info).
-    if (mlir::isa<fir::RecordType>(fir::unwrapSequenceType(fieldType)))
+    if (mlir::isa<fir::RecordType>(fir::unwrapSequenceType(fieldType)) &&
+        !fir::isa_builtin_c_devptr_type(fir::unwrapSequenceType(fieldType)))
       return false;
     // Allocatable components need deep copy.
     if (auto boxType = mlir::dyn_cast<fir::BaseBoxType>(fieldType))
diff --git a/flang/test/Lower/CUDA/cuda-devptr.cuf b/flang/test/Lower/CUDA/cuda-devptr.cuf
index d61d84d9bc750f..0a9087cf6c1334 100644
--- a/flang/test/Lower/CUDA/cuda-devptr.cuf
+++ b/flang/test/Lower/CUDA/cuda-devptr.cuf
@@ -4,6 +4,12 @@
 
 module cudafct
   use __fortran_builtins, only : c_devptr => __builtin_c_devptr
+  
+  type :: t1
+    type(c_devptr) :: devp
+    integer :: a
+  end type
+
 contains
   function c_devloc(x)
     use iso_c_binding, only: c_loc
@@ -12,6 +18,10 @@ contains
     real, target, device :: x
     c_devloc%cptr = c_loc(x)
   end function
+
+  attributes(device) function get_t1()
+    type(t1) :: get_t1
+  end
 end
 
 subroutine sub1()
@@ -68,3 +78,12 @@ end subroutine
 ! CHECK: %[[P_ADDR_COORD:.*]] = fir.coordinate_of %[[P_CPTR_COORD]], %[[ADDRESS_FIELD]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
 ! CHECK: %[[ADDR:.*]] = fir.load %[[RES_ADDR_COORD]] : !fir.ref<i64>
 ! CHECK: fir.store %[[ADDR]] to %[[P_ADDR_COORD]] : !fir.ref<i64>
+
+attributes(global) subroutine assign_nested_c_devptr(p, a)
+  use cudafct
+  type(t1), device :: p
+  p = get_t1()
+end subroutine
+
+! CHECK-LABEL: func.func @_QPassign_nested_c_devptr
+! CHECK-NOT: fir.call @_FortranAAssign

@clementval clementval merged commit 05fd4d5 into llvm:main Jan 24, 2025
11 checks passed
@clementval clementval deleted the cuf_devptr_nested branch January 24, 2025 22:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants