Skip to content

Commit

Permalink
Add config.VmModule argument to from_flatbuffer call. (nod-ai#266)
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet authored Aug 14, 2022
1 parent 7f5aaa3 commit 22ff92c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test-models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ jobs:
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --benchmark -k 'cpu' --ignore=shark/tests/test_shark_importer.py --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py
gsutil cp ./bench_results.csv gs://iree-shared-files/nod-perf/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
- name: Validate GPU Models
if: matrix.suite == 'gpu'
Expand All @@ -94,7 +94,7 @@ jobs:
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --benchmark -k "gpu" --ignore=shark/tests/test_shark_importer.py --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py
gsutil cp ./bench_results.csv gs://iree-shared-files/nod-perf/bench_results/${DATE}/bench_results_gpu_${SHORT_SHA}.csv
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_gpu_${SHORT_SHA}.csv
- name: Validate Vulkan Models
if: matrix.suite == 'vulkan'
Expand Down
4 changes: 3 additions & 1 deletion shark/iree_utils/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,10 @@ def compile_module_to_flatbuffer(

def get_iree_module(flatbuffer_blob, device, func_name):
# Returns the compiled module and the configs.
vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)
config = ireert.Config(IREE_DEVICE_MAP[device])
vm_module = ireert.VmModule.from_flatbuffer(
config.vm_instance, flatbuffer_blob
)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(vm_module)
ModuleCompiled = ctx.modules.module[func_name]
Expand Down

0 comments on commit 22ff92c

Please sign in to comment.