Skip to content

Commit

Permalink
Add e2e test with bfloat16 (Xilinx#446)
Browse files Browse the repository at this point in the history
Add an end-to-end test with bfloat16 inputs and f32 outputs.
  • Loading branch information
fifield authored Feb 22, 2024
1 parent e56552e commit edba97f
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 0 deletions.
118 changes: 118 additions & 0 deletions test/xrt/06_add_shim_bf16/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# gen.py -*- Python -*-
#
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT

import air
from air.ir import *
from air.passmanager import *
from air.dialects import air as airdialect
from air.dialects import arith, func, linalg, memref
from air.dialects.linalg.opdsl.lang import *
from air._mlir_libs._airMlir import _run_air_transform as run_air_transform

def generate_add_module(shape, dtype):
    """Build and tile the AIR module for an element-wise addition.

    The module contains a single function taking two input memrefs of
    ``shape`` with element type ``dtype`` and one output memref of ``shape``
    with element type f32.  The body is one ``linalg.elemwise_binary`` add.
    The named op is generalized, rewritten by the transform-dialect script
    below, and finally cleaned up with canonicalize + cse.

    Args:
        shape: dimensions used for all three memref arguments.
        dtype: element type of the two input memrefs.

    Returns:
        The transformed module.

    NOTE(review): presumably requires an active MLIR context and location;
    the caller in this file invokes it inside
    ``with Context() as ctx, Location.unknown()``.
    """
    add_module = Module.create()
    with InsertionPoint(add_module.body):
        # Two inputs of the requested element type, one f32 output.
        operand_types = [MemRefType.get(shape, dtype) for _ in range(2)]
        operand_types.append(MemRefType.get(shape, F32Type.get()))

        # The function is named `mul` even though it emits an add
        # (fun=BinaryFn.add); the name is kept so the generated symbol
        # does not change.
        @func.FuncOp.from_py_func(*operand_types)
        def mul(lhs, rhs, out):
            linalg.elemwise_binary(
                lhs, rhs, outs=[out], fun=BinaryFn.add, cast=TypeFn.cast_unsigned
            )

    # For debugging, print `add_module` here to inspect the untiled linalg IR.

    # Transform script: tile the generic op twice, promote its operands to
    # L1, turn the outer tile loops into a herd and the copies into DMAs.
    transform_script = """
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @match_copy : benefit(1) {
%args = pdl.operands
%results = pdl.types
%op = pdl.operation "memref.copy"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
pdl.rewrite %op with "transform.dialect"
}
transform.sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%l0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
%l1, %herd_tile_loops = transform.air.linalg_tile %l0 [0,128]
%l3, %inner_tile_loops:2 = transform.air.linalg_tile %l1 [32,32]
transform.air.linalg_promote %l3 {"operands_to_promote"=[0,1,2], "memory_space"="L1"}
%herd = transform.air.par_to_herd %herd_tile_loops
%copies = transform.pdl_match @match_copy in %arg0 : (!pdl.operation) -> !pdl.operation
%h = transform.air.copy_to_dma %copies
}
}
"""

    # The transform script matches "linalg.generic", so the named op has to
    # be generalized before the script runs.
    generalize = PassManager.parse('builtin.module(func.func(linalg-generalize-named-ops))')
    generalize.run(add_module.operation)
    run_air_transform(Module.parse(transform_script), add_module)

    cleanup = PassManager.parse('builtin.module(func.func(canonicalize,cse))')
    cleanup.run(add_module.operation)
    return add_module

# Script body: generate the tiled AIR module, lower it stage by stage and
# hand the result to aiecc.  The passes and the aiecc call are kept inside
# this scope (assumed: they rely on the context activated here).
with Context() as ctx, Location.unknown():
    airdialect.register_dialect(ctx)
    mlir_module = generate_add_module([128,128], BF16Type.get())

    print("\nTiled AIR Module:\n\n", mlir_module)

    # One pass list per lowering stage; each list is wrapped in
    # "builtin.module(...)" and run in order.  To inspect the IR after a
    # stage, write str(mlir_module) to a file at the end of the loop body.
    lowering_stages = (
        # Stage 1: AIR herds/DMAs to channels and loops.
        # "air-dependency" and "air-dependency-schedule-opt" are left disabled.
        [
            "func.func(air-lower-herd-parallel)",
            "air-dma-to-channel",
            "canonicalize",
            "cse",
            "air-specialize-channel-wrap-and-stride",
            "func.func(convert-linalg-to-loops)",
            'func.func(air-renumber-dma)',
        ],
        # Stage 2: AIR to AIE plus the runtime dialect.
        [
            "air-to-aie{emit-while-loop=true device=ipu row-offset=2 col-offset=0 use-objectfifo=false}",
            "air-to-std",
            "canonicalize",
            "cse",
        ],
        # Stage 3: runtime dialect to IPU instructions.
        [
            "airrt-to-ipu",
            "canonicalize",
            "cse",
        ],
    )
    for stage_passes in lowering_stages:
        pipeline = "builtin.module(" + ",".join(stage_passes) + ")"
        pm = PassManager.parse(pipeline)
        pm.run(mlir_module.operation)

    # Imported late, as in the original, so aiecc is only loaded once the
    # lowering above has run.
    import aie.compiler.aiecc.main as aiecc

    aiecc_options = [
        '--no-aiesim',
        '--aie-generate-cdo',
        '--aie-generate-ipu',
        '--no-compile-host',
        '--ipu-insts-name=insts.txt',
        '--xclbin-name=add.xclbin',
        'aie.mlir',
    ]
    aiecc.run(mlir_module, aiecc_options)
6 changes: 6 additions & 0 deletions test/xrt/06_add_shim_bf16/run.lit
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// (c) Copyright 2024 Advanced Micro Devices, Inc.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// REQUIRES: ryzen_ai, valid_xchess_license
// RUN: %python %S/gen.py
// RUN: %run_on_ipu %python %S/run.py add.xclbin
67 changes: 67 additions & 0 deletions test/xrt/06_add_shim_bf16/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# run.py -*- Python -*-
#
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT

import sys

import numpy as np
from bfloat16 import bfloat16
import pyxrt as xrt

# Number of elements in each input and in the output.
in_size = out_size = 128*128

# Buffer sizes in bytes: inputs are written as 2-byte words (bfloat16 viewed
# as int16 below), the output is read back as 4-byte float32.
in_size_bytes = in_size * 2
out_size_bytes = out_size * 4

# insts.txt holds one hexadecimal 32-bit word per line; skip empty lines.
with open('insts.txt', 'r') as f:
    instr_text = [line for line in f.read().split('\n') if line != '']
instr_v = np.array([int(word, 16) for word in instr_text], dtype=np.uint32)

# Use the xclbin path from the command line when one is given (run.lit
# passes "add.xclbin"); otherwise fall back to the previous hard-coded name.
opts_xclbin = sys.argv[1] if len(sys.argv) > 1 else 'add.xclbin'
opts_kernel = 'MLIR_AIE'

device = xrt.device(0)
xclbin = xrt.xclbin(opts_xclbin)
kernels = xclbin.get_kernels()

# Pick the first kernel whose name contains opts_kernel.  Only "no match" is
# treated as the not-found case; any other error propagates instead of being
# hidden by a bare except.
xkernel = next((k for k in kernels if opts_kernel in k.get_name()), None)
if xkernel is None:
    print(f"Kernel '{opts_kernel}' not found in '{opts_xclbin}'")
    sys.exit(-1)

device.register_xclbin(xclbin)
context = xrt.hw_context(device, xclbin.get_uuid())
kernel = xrt.kernel(context, xkernel.get_name())

# Buffer objects are bound to the kernel argument groups in call order.
# NOTE(review): group 1 is skipped, presumably because that argument is the
# scalar instruction count passed to kernel() below; confirm.
bo_instr = xrt.bo(device, len(instr_v)*4, xrt.bo.cacheable, kernel.group_id(0))
bo_a = xrt.bo(device, in_size_bytes, xrt.bo.host_only, kernel.group_id(2))
bo_b = xrt.bo(device, in_size_bytes, xrt.bo.host_only, kernel.group_id(3))
bo_c = xrt.bo(device, out_size_bytes, xrt.bo.host_only, kernel.group_id(4))

bo_instr.write(instr_v, 0)
bo_instr.sync(xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE)

# The bfloat16 arrays are written through an int16 view, i.e. as raw bits.
input_a = np.random.rand(in_size).astype(bfloat16)
bo_a.write(input_a.view(np.int16), 0)
bo_a.sync(xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE)

input_b = np.random.rand(in_size).astype(bfloat16)
bo_b.write(input_b.view(np.int16), 0)
bo_b.sync(xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE)

# Launch the kernel and block until it completes.
h = kernel(bo_instr, len(instr_v), bo_a, bo_b, bo_c)
h.wait()

bo_c.sync(xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE)
output_buffer = bo_c.read(out_size_bytes, 0).view(np.float32)
print("input:", input_a)
print("input:", input_b)
print("output:", output_buffer)

# NOTE(review): ref is the sum of two bfloat16 arrays while output_buffer is
# float32; confirm the exact-equality check is intended for this mix of
# precisions (a float32 reference may be more appropriate).
ref = input_a + input_b
if np.equal(ref, output_buffer).all():
    print("PASS!")
    sys.exit(0)
else:
    print("failed.")
    sys.exit(-1)
1 change: 1 addition & 0 deletions utils/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ rich
pybind11
numpy
cmake
bfloat16

0 comments on commit edba97f

Please sign in to comment.