Open
Description
🐛 Describe the bug
#include <iostream>
#include <cassert>
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>
using namespace ::executorch::extension;
int main() {
Module module("mv2_xnnpack.pte");
const auto error = module.load();
assert(module.is_loaded());
float input[1 * 3 * 224 * 224];
auto tensor = from_blob(input, {1, 3, 224, 224});
const auto result = module.execute("forward", tensor);
if (result.ok()) {
// Retrieve the output data.
std::cout << "HERE" << std::endl;
const auto output = result->at(0).toTensor().const_data_ptr<float>();
}
}
Build with CMake:
cmake_minimum_required(VERSION 3.21)
project(exec)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(executorch_DIR "executorch/build1/executorch_install/lib/cmake/ExecuTorch")
find_package(executorch REQUIRED)
set(_common_compile_options -Wno-deprecated-declarations)
add_executable(exec main.cpp)
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
target_link_libraries(
exec PRIVATE
${EXECUTORCH_LIBRARIES}
)
target_link_options(exec PRIVATE -Wl,--whole-archive,--allow-multiple-definition)
message(STATUS ${EXECUTORCH_LIBRARIES})
message(STATUS ${EXECUTORCH_INCLUDE_DIRS})
message(STATUS ${libs})
message(STATUS ${EXECUTORCH_LIBRARY})
Executorch build command:
cmake -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
-DEXECUTORCH_BUILD_DEVTOOLS=ON \
-DBUILD_EXECUTORCH_PORTABLE_OPS=ON \
-DEXECUTORCH_BUILD_PYBIND=ON \
-DEXECUTORCH_BUILD_XNNPACK=ON ../ && \
cmake --build . -j$(nproc --all) && \
cmake --install . --prefix executorch_install
Get the error:
D 00:00:00.000184 executorch:operator_registry.cpp:92] Successfully registered all kernels from shared library: NOT_SUPPORTED
E 00:00:00.000222 executorch:operator_registry.cpp:85] Re-registering aten::sym_size.int, from NOT_SUPPORTED
E 00:00:00.000232 executorch:operator_registry.cpp:86] key: (null), is_fallback: true
F 00:00:00.000233 executorch:operator_registry.cpp:106] In function register_kernels(), assert failed (false): Kernel registration failed with error 18, see error log for details.
Aborted (core dumped)
Code of model export:
import torch
import torchvision.models as models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower
mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )
et_program = to_edge_transform_and_lower(
torch.export.export(mobilenet_v2, sample_inputs),
#partitioner=[XnnpackPartitioner()],
).to_executorch()
with open("mv2_xnnpack.pte", "wb") as file:
et_program.write_to_file(file)
from pathlib import Path
from executorch.runtime import Verification, Runtime, Program, Method
et_runtime: Runtime = Runtime.get()
program: Program = et_runtime.load_program(
Path("mv2_xnnpack.pte"),
verification=Verification.Minimal,
)
print("Program methods:", program.method_names)
forward: Method = program.load_method("forward")
inputs = (torch.ones(1, 3, 224, 224),)
outputs = forward.execute(inputs)
print(f"Ran forward({inputs})")
print(f" outputs: {outputs}")
Versions
Collecting environment information...
PyTorch version: 2.6.0+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.2 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: 18.1.3 (1ubuntu1)
CMake version: version 3.31.6
Libc version: glibc-2.39
Python version: 3.10.0 (default, Mar 3 2022, 09:58:08) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4070 SUPER
Nvidia driver version: 572.42
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 20
On-line CPU(s) list: 0-19
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) i5-14600KF
CPU family: 6
Model: 183
Thread(s) per core: 2
Core(s) per socket: 10
Socket(s): 1
Stepping: 1
BogoMIPS: 6988.79
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 480 KiB (10 instances)
L1i cache: 320 KiB (10 instances)
L2 cache: 20 MiB (10 instances)
L3 cache: 24 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] executorch==0.5.0a0+1bc0699
[pip3] numpy==2.0.0
[pip3] torch==2.6.0+cpu
[pip3] torchao==0.8.0+gitebc43034
[pip3] torchaudio==2.6.0+cpu
[pip3] torchsr==1.0.4
[pip3] torchvision==0.21.0+cpu
[conda] executorch 0.5.0a0+1bc0699 pypi_0 pypi
[conda] numpy 2.0.0 pypi_0 pypi
[conda] torch 2.6.0+cpu pypi_0 pypi
[conda] torchao 0.8.0+gitebc43034 pypi_0 pypi
[conda] torchaudio 2.6.0+cpu pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchvision 0.21.0+cpu pypi_0 pypi
I try to link builded executorch to my project with find_package
, but get the folowing error when try start inference. But if I use add_subdirectory(executorch)
like:
cmake_minimum_required(VERSION 3.21)
project(exec)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(_common_compile_options -Wno-deprecated-declarations)
set(EXECUTORCH_BUILD_EXECUTOR_RUNNER ON)
set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
set(EXECUTORCH_BUILD_EXTENSION_MODULE ON)
set(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL ON)
set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
set(EXECUTORCH_BUILD_XNNPACK ON)
add_subdirectory("executorch")
add_executable(exec main.cpp)
target_link_libraries(
exec
PRIVATE executorch
extension_module_static
extension_tensor
portable_kernels
portable_ops_lib
executorch_core
#optimized_native_cpu_ops_lib
xnnpack_backend)
target_link_options(exec PRIVATE -Wl,--whole-archive,--allow-multiple-definition)
It works for me, but in my project I need to build executorch outside my project
Metadata
Metadata
Assignees
Labels
Type
Projects
Status
To triage