[flang][cuda] Make argument passed by value for sync functions #125909

Merged: 1 commit, Feb 5, 2025

@clementval clementval commented Feb 5, 2025

syncthreads_and, syncthreads_count, syncthreads_or, and syncwarp must take their argument by value. This patch updates the interfaces accordingly and makes sure these functions can also be called inside a CUF kernel.
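
For illustration, a minimal sketch of the kind of call this enables, modeled on the `host1` test case added in this PR. The subroutine name `host_example` and the variable names are hypothetical, and the loop body only demonstrates call syntax rather than a meaningful computation. With the `value` attribute on the dummy arguments, a literal such as `1` is passed directly as an `i32` instead of through a reference.

```fortran
! Hypothetical example; only the cudadevice procedure names come from the patch.
subroutine host_example()
  use cudadevice            ! provides the syncthreads_* and syncwarp interfaces
  implicit none
  integer, device :: a(32)
  integer, device :: ret    ! receives the result of the sync functions
  integer :: i

  ! CUF kernel: the loop body is compiled as device code.
  !$cuf kernel do <<<*, *>>>
  do i = 1, 32
    a(i) = a(i) * 2
    call syncthreads()

    ! Arguments are now passed by value, so literals can be used directly.
    ret = syncthreads_and(1)
    ret = syncthreads_count(1)
    ret = syncthreads_or(1)
  end do
end subroutine
```

Building this requires a CUDA Fortran capable compiler; the expected lowering for the equivalent test is given by the CHECK lines in `cuda-device-proc.cuf` in the diff.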

@llvmbot added the flang (Flang issues not falling into any other category) and flang:fir-hlfir labels Feb 5, 2025
llvmbot commented Feb 5, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)


Full diff: https://github.com/llvm/llvm-project/pull/125909.diff

2 Files Affected:

  • (modified) flang/module/cudadevice.f90 (+4-4)
  • (modified) flang/test/Lower/CUDA/cuda-device-proc.cuf (+24-10)
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index 00e8b3db73ad87..1fe99b30b1db08 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -29,28 +29,28 @@ attributes(device) subroutine syncthreads()
 
   interface
     attributes(device) integer function syncthreads_and(value)
-      integer :: value
+      integer, value :: value
     end function
   end interface
   public :: syncthreads_and
 
   interface
     attributes(device) integer function syncthreads_count(value)
-      integer :: value
+      integer, value :: value
     end function
   end interface
   public :: syncthreads_count
 
   interface
     attributes(device) integer function syncthreads_or(value)
-      integer :: value
+      integer, value :: value
     end function
   end interface
   public :: syncthreads_or
 
   interface
     attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp')
-      integer :: mask
+      integer, value :: mask
     end subroutine
   end interface
   public :: syncwarp
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 5805dd5010a842..ec825263474c1e 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -47,7 +47,7 @@ end
 
 ! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
 ! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
-! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> ()
+! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
 ! CHECK: fir.call @llvm.nvvm.membar.gl() fastmath<contract> : () -> ()
 ! CHECK: fir.call @llvm.nvvm.membar.cta() fastmath<contract> : () -> ()
 ! CHECK: fir.call @llvm.nvvm.membar.sys() fastmath<contract> : () -> ()
@@ -79,17 +79,9 @@ end
 ! CHECK: %{{.*}} = llvm.atomicrmw uinc_wrap  %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
 ! CHECK: %{{.*}} = llvm.atomicrmw udec_wrap  %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
 
-! CHECK: func.func private @llvm.nvvm.barrier0()
-! CHECK: func.func private @__syncwarp(!fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>}) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
-! CHECK: func.func private @llvm.nvvm.membar.gl()
-! CHECK: func.func private @llvm.nvvm.membar.cta()
-! CHECK: func.func private @llvm.nvvm.membar.sys()
-! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
-! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
-! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
-
 subroutine host1()
   integer, device :: a(32)
+  integer, device :: ret
   integer :: i, j
 
 block; use cudadevice
@@ -98,6 +90,28 @@ block; use cudadevice
     a(i) = a(i) * 2.0
     call syncthreads()
     a(i) = a(i) + a(j) - 34.0
+
+    call syncwarp(1)
+    ret = syncthreads_and(1)
+    ret = syncthreads_count(1)
+    ret = syncthreads_or(1)
   end do
 end block
 end 
+
+! CHECK-LABEL: func.func @_QPhost1()
+! CHECK: cuf.kernel
+! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
+! CHECK: fir.call @__syncwarp(%c1{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
+! CHECK: fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
+! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
+! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
+
+! CHECK: func.func private @llvm.nvvm.barrier0()
+! CHECK: func.func private @__syncwarp(i32) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
+! CHECK: func.func private @llvm.nvvm.membar.gl()
+! CHECK: func.func private @llvm.nvvm.membar.cta()
+! CHECK: func.func private @llvm.nvvm.membar.sys()
+! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
+! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
+! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32

@clementval clementval merged commit 69ccb13 into llvm:main Feb 5, 2025
9 of 10 checks passed
@clementval clementval deleted the cuf_sync_byvalue branch February 5, 2025 21:47