@@ -52,7 +52,7 @@ def _init_llvm_types_and_constants(self):
52
52
self .byte_ptr_t = lc .Type .pointer (self .byte_t )
53
53
self .byte_ptr_ptr_t = lc .Type .pointer (self .byte_ptr_t )
54
54
self .intp_t = self .context .get_value_type (types .intp )
55
- self .long_t = self .context .get_value_type (types .int64 )
55
+ self .int64_t = self .context .get_value_type (types .int64 )
56
56
self .int32_t = self .context .get_value_type (types .int32 )
57
57
self .int32_ptr_t = lc .Type .pointer (self .int32_t )
58
58
self .uintp_t = self .context .get_value_type (types .uintp )
@@ -113,23 +113,26 @@ def allocate_kenrel_arg_array(self, num_kernel_args):
113
113
114
114
115
115
def resolve_and_return_dpctl_type (self , ty ):
116
+ """This function looks up the dpctl defined enum values from DPCTLKernelArgType.
117
+ """
118
+
116
119
val = None
117
120
if ty == types .int32 or isinstance (ty , types .scalars .IntegerLiteral ):
118
- val = self .context .get_constant (types .int32 , 4 )
121
+ val = self .context .get_constant (types .int32 , 9 ) # DPCTL_LONG_LONG
119
122
elif ty == types .uint32 :
120
- val = self .context .get_constant (types .int32 , 5 )
123
+ val = self .context .get_constant (types .int32 , 10 ) # DPCTL_UNSIGNED_LONG_LONG
121
124
elif ty == types .boolean :
122
- val = self .context .get_constant (types .int32 , 5 )
125
+ val = self .context .get_constant (types .int32 , 5 ) # DPCTL_UNSIGNED_INT
123
126
elif ty == types .int64 :
124
- val = self .context .get_constant (types .int32 , 7 )
127
+ val = self .context .get_constant (types .int32 , 9 ) # DPCTL_LONG_LONG
125
128
elif ty == types .uint64 :
126
- val = self .context .get_constant (types .int32 , 8 )
129
+ val = self .context .get_constant (types .int32 , 11 ) # DPCTL_SIZE_T
127
130
elif ty == types .float32 :
128
- val = self .context .get_constant (types .int32 , 12 )
131
+ val = self .context .get_constant (types .int32 , 12 ) # DPCTL_FLOAT
129
132
elif ty == types .float64 :
130
- val = self .context .get_constant (types .int32 , 13 )
133
+ val = self .context .get_constant (types .int32 , 13 ) # DPCTL_DOUBLE
131
134
elif ty == types .voidptr :
132
- val = self .context .get_constant (types .int32 , 15 )
135
+ val = self .context .get_constant (types .int32 , 15 ) # DPCTL_VOID_PTR
133
136
else :
134
137
raise NotImplementedError
135
138
@@ -151,12 +154,12 @@ def process_kernel_arg(self, var, llvm_arg, arg_type, gu_sig, val_type, index, m
151
154
if llvm_arg is None :
152
155
raise NotImplementedError (arg_type , var )
153
156
154
- storage = cgutils .alloca_once (self .builder , self .long_t )
157
+ storage = cgutils .alloca_once (self .builder , self .int64_t )
155
158
self .builder .store (self .context .get_constant (types .int64 , 0 ), storage )
156
159
ty = self .resolve_and_return_dpctl_type (types .int64 )
157
160
self .form_kernel_arg_and_arg_ty (self .builder .bitcast (storage , self .void_ptr_t ), ty )
158
161
159
- storage = cgutils .alloca_once (self .builder , self .long_t )
162
+ storage = cgutils .alloca_once (self .builder , self .int64_t )
160
163
self .builder .store (self .context .get_constant (types .int64 , 0 ), storage )
161
164
ty = self .resolve_and_return_dpctl_type (types .int64 )
162
165
self .form_kernel_arg_and_arg_ty (self .builder .bitcast (storage , self .void_ptr_t ), ty )
0 commit comments