Skip to content

Commit 7e38793

Browse files
authored
[flang][cuda] Make sure stream is a i64 reference (#157957)
When the stream is a scalar constant, it is lowered as i32. Stream needs to be i64 to pass the verifier. Detect and update the stream reference when it is i32.
1 parent 8fae5a5 commit 7e38793

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

flang/lib/Lower/ConvertCall.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -639,9 +639,18 @@ Fortran::lower::genCallOpAndResult(
639639
caller.getCallDescription().chevrons()[2], stmtCtx)));
640640

641641
mlir::Value stream; // stream is optional.
642-
if (caller.getCallDescription().chevrons().size() > 3)
642+
if (caller.getCallDescription().chevrons().size() > 3) {
643643
stream = fir::getBase(converter.genExprAddr(
644644
caller.getCallDescription().chevrons()[3], stmtCtx));
645+
if (!fir::unwrapRefType(stream.getType()).isInteger(64)) {
646+
auto i64Ty = mlir::IntegerType::get(builder.getContext(), 64);
647+
mlir::Value newStream = builder.createTemporary(loc, i64Ty);
648+
mlir::Value load = fir::LoadOp::create(builder, loc, stream);
649+
mlir::Value conv = fir::ConvertOp::create(builder, loc, i64Ty, load);
650+
fir::StoreOp::create(builder, loc, conv, newStream);
651+
stream = newStream;
652+
}
653+
}
645654

646655
cuf::KernelLaunchOp::create(builder, loc, funcType.getResults(),
647656
funcSymbolAttr, grid_x, grid_y, grid_z, block_x,
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
2+
3+
attributes(global) subroutine sharedmem()
4+
real, shared :: s(*)
5+
integer :: t
6+
t = threadIdx%x
7+
s(t) = t
8+
end subroutine
9+
10+
program test
11+
call sharedmem<<<1, 1, 1024, 0>>>()
12+
end
13+
14+
! CHECK-LABEL: func.func @_QQmain()
15+
! CHECK: cuf.kernel_launch @_QPsharedmem<<<%c1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c1024{{.*}}, %{{.*}} : !fir.ref<i64>>>>()

0 commit comments

Comments
 (0)