|
19 | 19 |
|
20 | 20 | extern crate bindgen;
|
21 | 21 |
|
22 |
| -use std::env; |
23 |
| -use std::path::PathBuf; |
| 22 | +use std::path::{Path, PathBuf}; |
24 | 23 |
|
25 | 24 | use anyhow::{Context, Result};
|
| 25 | +use tvm_build::BuildConfig; |
| 26 | + |
| 27 | +/// The necessary information for detecting a TVM installation. |
| 28 | +struct TVMInstall { |
| 29 | + source_path: PathBuf, |
| 30 | + build_path: PathBuf, |
| 31 | +} |
| 32 | + |
| 33 | +/// Find the TVM install using the provided path. |
| 34 | +fn find_using_tvm_path<P: AsRef<Path>>(tvm_path: P) -> Result<TVMInstall> { |
| 35 | + Ok(TVMInstall { |
| 36 | + source_path: tvm_path.as_ref().into(), |
| 37 | + build_path: tvm_path.as_ref().into(), |
| 38 | + }) |
| 39 | +} |
| 40 | + |
| 41 | +#[allow(unused)] |
| 42 | +fn if_unset<K: AsRef<std::ffi::OsStr>, V: AsRef<std::ffi::OsStr>>(k: K, v: V) -> Result<()> { |
| 43 | + match std::env::var(k.as_ref()) { |
| 44 | + Ok(other) if other != "" => { |
| 45 | + println!( |
| 46 | + "cargo:warning=Using existing environment variable setting {:?}={:?}", |
| 47 | + k.as_ref(), |
| 48 | + v.as_ref() |
| 49 | + ); |
| 50 | + } |
| 51 | + _ => std::env::set_var(k, v), |
| 52 | + } |
| 53 | + |
| 54 | + Ok(()) |
| 55 | +} |
| 56 | + |
| 57 | +/// Find a TVM installation using TVM build by either first installing or detecting. |
| 58 | +fn find_using_tvm_build() -> Result<TVMInstall> { |
| 59 | + let mut build_config = BuildConfig::default(); |
| 60 | + build_config.repository = Some("https://github.com/apache/tvm".to_string()); |
| 61 | + build_config.branch = Some(option_env!("TVM_BRANCH").unwrap_or("main").into()); |
| 62 | + let build_result = tvm_build::build(build_config)?; |
| 63 | + let source_path = build_result.revision.source_path(); |
| 64 | + let build_path = build_result.revision.build_path(); |
| 65 | + Ok(TVMInstall { |
| 66 | + source_path, |
| 67 | + build_path, |
| 68 | + }) |
| 69 | +} |
26 | 70 |
|
27 | 71 | fn main() -> Result<()> {
|
28 |
| - let tvm_home = option_env!("TVM_HOME") |
29 |
| - .map::<Result<String>, _>(|s: &str| Ok(str::to_string(s))) |
30 |
| - .unwrap_or_else(|| { |
31 |
| - let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")) |
32 |
| - .canonicalize() |
33 |
| - .with_context(|| { |
34 |
| - format!( |
35 |
| - "failed to cannonicalize() CARGO_MANIFEST_DIR={}", |
36 |
| - env!("CARGO_MANIFEST_DIR") |
37 |
| - ) |
38 |
| - })?; |
39 |
| - |
40 |
| - Ok(crate_dir |
41 |
| - .parent() |
42 |
| - .with_context(|| { |
43 |
| - format!( |
44 |
| - "failed to find parent of CARGO_MANIFEST_DIR={}", |
45 |
| - env!("CARGO_MANIFEST_DIR") |
46 |
| - ) |
47 |
| - })? |
48 |
| - .parent() |
49 |
| - .with_context(|| { |
50 |
| - format!( |
51 |
| - "failed to find the parent of the parent of CARGO MANIFEST_DIR={}", |
52 |
| - env!("CARGO_MANIFEST_DIR") |
53 |
| - ) |
54 |
| - })? |
55 |
| - .to_str() |
56 |
| - .context("failed to convert to strings")? |
57 |
| - .to_string()) |
58 |
| - })?; |
59 |
| - |
60 |
| - if cfg!(feature = "bindings") { |
61 |
| - println!("cargo:rerun-if-env-changed=TVM_HOME"); |
| 72 | + let TVMInstall { |
| 73 | + source_path, |
| 74 | + build_path, |
| 75 | + } = match option_env!("TVM_HOME") { |
| 76 | + Some(tvm_path) if tvm_path != "" => find_using_tvm_path(tvm_path), |
| 77 | + _ => find_using_tvm_build(), |
| 78 | + }?; |
| 79 | + |
| 80 | + // If the TVM_HOME environment variable changed, the LLVM_CONFIG_PATH environment variable |
| 81 | + // changed, the build directory or headers have changed we need to rebuild the Rust bindings. |
| 82 | + println!("cargo:rerun-if-env-changed=TVM_HOME"); |
| 83 | + println!("cargo:rerun-if-env-changed=LLVM_CONFIG_PATH"); |
| 84 | + println!("cargo:rerun-if-changed={}", build_path.display()); |
| 85 | + println!("cargo:rerun-if-changed={}/include", source_path.display()); |
| 86 | + |
| 87 | + if cfg!(feature = "static-linking") { |
| 88 | + println!("cargo:rustc-link-lib=static=tvm"); |
| 89 | + // TODO(@jroesch): move this to tvm-build as library_path? |
| 90 | + println!( |
| 91 | + "cargo:rustc-link-search=native={}/build", |
| 92 | + build_path.display() |
| 93 | + ); |
| 94 | + } |
| 95 | + |
| 96 | + if cfg!(feature = "dynamic-linking") { |
62 | 97 | println!("cargo:rustc-link-lib=dylib=tvm");
|
63 |
| - println!("cargo:rustc-link-search=native={}/build", tvm_home); |
| 98 | + println!( |
| 99 | + "cargo:rustc-link-search=native={}/build", |
| 100 | + build_path.display() |
| 101 | + ); |
64 | 102 | }
|
65 | 103 |
|
| 104 | + let runtime_api = source_path.join("include/tvm/runtime/c_runtime_api.h"); |
| 105 | + let backend_api = source_path.join("include/tvm/runtime/c_backend_api.h"); |
| 106 | + let source_path = source_path.display().to_string(); |
| 107 | + let dlpack_include = format!("-I{}/3rdparty/dlpack/include/", source_path); |
| 108 | + let tvm_include = format!("-I{}/include/", source_path); |
| 109 | + |
| 110 | + let out_file = PathBuf::from(std::env::var("OUT_DIR")?).join("c_runtime_api.rs"); |
| 111 | + |
66 | 112 | // @see rust-bindgen#550 for `blacklist_type`
|
67 | 113 | bindgen::Builder::default()
|
68 |
| - .header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home)) |
69 |
| - .header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home)) |
70 |
| - .clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home)) |
71 |
| - .clang_arg(format!("-I{}/include/", tvm_home)) |
| 114 | + .header(runtime_api.display().to_string()) |
| 115 | + .header(backend_api.display().to_string()) |
| 116 | + .clang_arg(dlpack_include) |
| 117 | + .clang_arg(tvm_include) |
72 | 118 | .blacklist_type("max_align_t")
|
73 | 119 | .layout_tests(false)
|
74 | 120 | .derive_partialeq(true)
|
75 | 121 | .derive_eq(true)
|
76 | 122 | .derive_default(true)
|
77 | 123 | .generate()
|
78 |
| - .map_err(|()| anyhow::anyhow!("failed to generate bindings"))? |
79 |
| - .write_to_file(PathBuf::from("src/c_runtime_api.rs")) |
80 |
| - .context("failed to write bindings")?; |
| 124 | + .map_err(|()| { |
| 125 | + anyhow::anyhow!("bindgen failed to generate the Rust bindings for the C API") |
| 126 | + })? |
| 127 | + .write_to_file(out_file) |
| 128 | + .context("failed to write the generated Rust binding to disk")?; |
81 | 129 |
|
82 | 130 | Ok(())
|
83 | 131 | }
|
0 commit comments