17
17
18
18
19
19
@iron .jit
20
- def vector_vector_add (config , input0 , input1 , output ):
20
+ def vector_vector_add (input0 , input1 , output ):
21
21
if input0 .shape != input1 .shape :
22
22
raise ValueError (
23
23
f"Input shapes are not the equal ({ input0 .shape } != { input1 .shape } )."
@@ -48,16 +48,16 @@ def vector_vector_add(config, input0, input1, output):
48
48
49
49
buffer_depth = 2
50
50
51
- @device (config [ "device" ] )
51
+ @device (iron . get_current_device () )
52
52
def device_body ():
53
53
tensor_ty = np .ndarray [(num_elements ,), np .dtype [dtype ]]
54
54
tile_ty = np .ndarray [(n ,), np .dtype [dtype ]]
55
55
56
56
# AIE Core Function declarations
57
57
58
58
# Tile declarations
59
- ShimTile = tile (config [ "column_id" ] , 0 )
60
- ComputeTile2 = tile (config [ "column_id" ] , 2 )
59
+ ShimTile = tile (0 , 0 )
60
+ ComputeTile2 = tile (0 , 2 )
61
61
62
62
# AIE-array data movement with object fifos
63
63
of_in1 = object_fifo ("in1" , ShimTile , ComputeTile2 , buffer_depth , tile_ty )
@@ -114,9 +114,6 @@ def main():
114
114
default = "npu" ,
115
115
help = "Target device" ,
116
116
)
117
- parser .add_argument (
118
- "-c" , "--column" , type = int , default = 0 , help = "Column index (default: 0)"
119
- )
120
117
parser .add_argument (
121
118
"-n" ,
122
119
"--num-elements" ,
@@ -132,14 +129,11 @@ def main():
132
129
input1 = iron .randint (0 , 100 , (args .num_elements ,), dtype = np .int32 , device = "npu" )
133
130
output = iron .zeros_like (input0 )
134
131
132
+ iron .set_current_device (device_map [args .device ])
133
+
135
134
# JIT-compile the kernel, then launch it with the given arguments. Future calls
136
135
# to the kernel will use the same compiled kernel and loaded code objects
137
- vector_vector_add (
138
- {"device" : device_map [args .device ], "column_id" : args .column },
139
- input0 ,
140
- input1 ,
141
- output ,
142
- )
136
+ vector_vector_add (input0 , input1 , output )
143
137
144
138
# Check the correctness of the result
145
139
e = np .equal (input0 .numpy () + input1 .numpy (), output .numpy ())
0 commit comments