From 22ff92c48b590fc6678d4ecf7486e38815cb5776 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Sun, 14 Aug 2022 17:11:19 -0500 Subject: [PATCH] Add config.VmModule argument to from_flatbuffer call. (#266) --- .github/workflows/test-models.yml | 4 ++-- shark/iree_utils/compile_utils.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test-models.yml b/.github/workflows/test-models.yml index cc91446efe..256a1fb9ab 100644 --- a/.github/workflows/test-models.yml +++ b/.github/workflows/test-models.yml @@ -85,7 +85,7 @@ jobs: PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh source shark.venv/bin/activate pytest --benchmark -k 'cpu' --ignore=shark/tests/test_shark_importer.py --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py - gsutil cp ./bench_results.csv gs://iree-shared-files/nod-perf/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv + gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv - name: Validate GPU Models if: matrix.suite == 'gpu' @@ -94,7 +94,7 @@ jobs: PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh source shark.venv/bin/activate pytest --benchmark -k "gpu" --ignore=shark/tests/test_shark_importer.py --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py - gsutil cp ./bench_results.csv gs://iree-shared-files/nod-perf/bench_results/${DATE}/bench_results_gpu_${SHORT_SHA}.csv + gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_gpu_${SHORT_SHA}.csv - name: Validate Vulkan Models if: matrix.suite == 'vulkan' diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index e409152dbe..02fd0a1c93 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -98,8 +98,10 @@ def compile_module_to_flatbuffer( def get_iree_module(flatbuffer_blob, device, func_name): # Returns the compiled module and the configs. - vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob) config = ireert.Config(IREE_DEVICE_MAP[device]) + vm_module = ireert.VmModule.from_flatbuffer( + config.vm_instance, flatbuffer_blob + ) ctx = ireert.SystemContext(config=config) ctx.add_vm_module(vm_module) ModuleCompiled = ctx.modules.module[func_name]