Skip to content

Commit f6b8e2a

Browse files
committed
Passing arbitrary configuration
1 parent 85f45e7 commit f6b8e2a

File tree

3 files changed

+22
-13
lines changed

3 files changed

+22
-13
lines changed

programming_examples/basic/vector_vector_add/vector_vector_add.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919

2020
@iron.jit(is_placed=False)
21-
def vector_vector_add(device, input0, input1, output):
21+
def vector_vector_add(config, input0, input1, output):
2222
if input0.shape != input1.shape:
2323
raise ValueError(
2424
f"Input shapes are not the equal ({input0.shape} != {input1.shape})."
@@ -81,7 +81,7 @@ def core_body(of_in1, of_in2, of_out):
8181
rt.drain(of_out.cons(), C, wait=True)
8282

8383
# Place program components (assign them resources on the device) and generate an MLIR module
84-
return Program(device, rt).resolve_program(SequentialPlacer())
84+
return Program(config['device'], rt).resolve_program(SequentialPlacer())
8585

8686

8787
def main():
@@ -118,7 +118,7 @@ def main():
118118

119119
# JIT-compile the kernel then launches the kernel with the given arguments. Future calls
120120
# to the kernel will use the same compiled kernel and loaded code objects
121-
vector_vector_add(device_map[args.device], input0, input1, output)
121+
vector_vector_add({'device': device_map[args.device]}, input0, input1, output)
122122

123123
# Check the correctness of the result
124124
e = np.equal(input0.numpy() + input1.numpy(), output.numpy())

programming_examples/basic/vector_vector_add/vector_vector_add_placed.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
@iron.jit
20-
def vector_vector_add(dev, column_id, input0, input1, output):
20+
def vector_vector_add(config, input0, input1, output):
2121
if input0.shape != input1.shape:
2222
raise ValueError(
2323
f"Input shapes are not the equal ({input0.shape} != {input1.shape})."
@@ -48,16 +48,16 @@ def vector_vector_add(dev, column_id, input0, input1, output):
4848

4949
buffer_depth = 2
5050

51-
@device(dev)
51+
@device(config['device'])
5252
def device_body():
5353
tensor_ty = np.ndarray[(num_elements,), np.dtype[dtype]]
5454
tile_ty = np.ndarray[(n,), np.dtype[dtype]]
5555

5656
# AIE Core Function declarations
5757

5858
# Tile declarations
59-
ShimTile = tile(column_id, 0)
60-
ComputeTile2 = tile(column_id, 2)
59+
ShimTile = tile(config['column_id'], 0)
60+
ComputeTile2 = tile(config['column_id'], 2)
6161

6262
# AIE-array data movement with object fifos
6363
of_in1 = object_fifo("in1", ShimTile, ComputeTile2, buffer_depth, tile_ty)
@@ -128,13 +128,22 @@ def main():
128128

129129
# Construct two input random tensors and an output zeroed tensor
130130
# The three tensor are in memory accessible to the NPU
131-
input0 = iron.randint(0, 100, (args.num_elements,), dtype=np.int32, device=args.device)
132-
input1 = iron.randint(0, 100, (args.num_elements,), dtype=np.int32, device=args.device)
131+
input0 = iron.randint(
132+
0, 100, (args.num_elements,), dtype=np.int32, device=args.device
133+
)
134+
input1 = iron.randint(
135+
0, 100, (args.num_elements,), dtype=np.int32, device=args.device
136+
)
133137
output = iron.zeros_like(input0)
134138

135139
# JIT-compile the kernel then launches the kernel with the given arguments. Future calls
136140
# to the kernel will use the same compiled kernel and loaded code objects
137-
vector_vector_add(device_map[args.device], args.column, input0, input1, output)
141+
vector_vector_add(
142+
{"device": device_map[args.device], "column_id": args.column},
143+
input0,
144+
input1,
145+
output,
146+
)
138147

139148
# Check the correctness of the result
140149
e = np.equal(input0.numpy() + input1.numpy(), output.numpy())

python/iron/jit.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __init__(
8181
self.__insts_buffer_bo.sync(xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE)
8282

8383
# Blocking call.
84-
def __call__(self, *args):
84+
def __call__(self, config, *args):
8585
"""
8686
Allows the kernel to be called as a function with the provided arguments.
8787
@@ -101,7 +101,7 @@ def __call__(self, *args):
101101
h = self.__kernel(opcode, self.__insts_buffer_bo, self.__n_insts, *kernel_args)
102102
r = h.wait()
103103
if r != xrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED:
104-
raise Exception(f"Kernel returned {r}")
104+
raise NPUKernel_Error(f"Kernel returned {r}")
105105

106106
def __del__(self):
107107
"""
@@ -169,7 +169,7 @@ def wrapped_function(*args, **kwargs):
169169
)
170170

171171
kernel_name = "MLIR_AIE"
172-
return NPUKernel(xclbin_path, inst_path, kernel_name=kernel_name)
172+
return NPUKernel(xclbin_path, inst_path, kernel_name=kernel_name)(*args, **kwargs)
173173

174174
return wrapped_function
175175

0 commit comments

Comments
 (0)