Skip to content

Commit 648f53a

Browse files
authored
[Rust] Make TVM Rust bindings installable via Cargo. (#7503)
* Make TVM Cargo installable Rewrite the Rust Module API and change some imports causing crashes. This commit also updates the docs to remove outdated information. Fixes for version bump Update build.rs to use new tvm-build version Tweak build.rs to use release version of tvm-build Add docs Add Readme for tvm-sys crate. Fix Cargo verisions for pre-release Add README Move generated code to OUT_DIR Fix path Add descp for tvm-sys Tweak versions for publishing Tweak versions for publishing Add README for tvm-graph-rt Conform to Apache branding guidelines Fix caps Add header Remove warning Format Clean up build Turn docs back on Tweak CI WIP Remove CI changes * Disable docs * Fix
1 parent e426458 commit 648f53a

File tree

14 files changed

+196
-75
lines changed

14 files changed

+196
-75
lines changed

rust/tvm-graph-rt/Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
[package]
1919
name = "tvm-graph-rt"
20-
version = "0.1.0"
20+
version = "0.1.0-alpha"
2121
license = "Apache-2.0"
2222
description = "A static graph executor for TVM."
2323
repository = "https://github.com/apache/tvm"
@@ -38,8 +38,8 @@ nom = "5.0"
3838
num_cpus = "1.10"
3939
serde = { version = "^1.0", features = ["derive"] }
4040
serde_json = "^1.0"
41-
tvm-sys = { version = "0.1", path = "../tvm-sys" }
42-
tvm-macros = { version = "0.1", path = "../tvm-macros" }
41+
tvm-sys = { version = "0.1.1-alpha", path = "../tvm-sys" }
42+
tvm-macros = { version = "0.1.1-alpha", path = "../tvm-macros" }
4343

4444
[target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies]
4545
libloading = "0.5"

rust/tvm-graph-rt/README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
2+
<!--- or more contributor license agreements. See the NOTICE file -->
3+
<!--- distributed with this work for additional information -->
4+
<!--- regarding copyright ownership. The ASF licenses this file -->
5+
<!--- to you under the Apache License, Version 2.0 (the -->
6+
<!--- "License"); you may not use this file except in compliance -->
7+
<!--- with the License. You may obtain a copy of the License at -->
8+
9+
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
10+
11+
<!--- Unless required by applicable law or agreed to in writing, -->
12+
<!--- software distributed under the License is distributed on an -->
13+
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
14+
<!--- KIND, either express or implied. See the License for the -->
15+
<!--- specific language governing permissions and limitations -->
16+
<!--- under the License. -->
17+
18+
# tvm-graph-rt
19+
20+
An implementation of TVM's graph runtime in Rust. See `tvm` crate for more documentation.

rust/tvm-macros/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
[package]
1919
name = "tvm-macros"
20-
version = "0.1.1"
20+
version = "0.1.1-alpha"
2121
license = "Apache-2.0"
2222
description = "Procedural macros of the TVM crate."
2323
repository = "https://github.com/apache/tvm"

rust/tvm-macros/README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
2+
<!--- or more contributor license agreements. See the NOTICE file -->
3+
<!--- distributed with this work for additional information -->
4+
<!--- regarding copyright ownership. The ASF licenses this file -->
5+
<!--- to you under the Apache License, Version 2.0 (the -->
6+
<!--- "License"); you may not use this file except in compliance -->
7+
<!--- with the License. You may obtain a copy of the License at -->
8+
9+
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
10+
11+
<!--- Unless required by applicable law or agreed to in writing, -->
12+
<!--- software distributed under the License is distributed on an -->
13+
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
14+
<!--- KIND, either express or implied. See the License for the -->
15+
<!--- specific language governing permissions and limitations -->
16+
<!--- under the License. -->
17+
18+
# tvm-macros
19+
20+
The procedural macro implementations for TVM crates, see `tvm` crate for more documentation.

rust/tvm-rt/Cargo.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
[package]
1919
name = "tvm-rt"
20-
version = "0.1.0"
20+
version = "0.1.0-alpha"
2121
license = "Apache-2.0"
2222
description = "Rust bindings for the TVM runtime API."
2323
repository = "https://github.com/apache/tvm"
@@ -30,22 +30,22 @@ edition = "2018"
3030

3131
[features]
3232
default = ["dynamic-linking"]
33-
dynamic-linking = ["tvm-sys/bindings"]
34-
static-linking = []
33+
dynamic-linking = ["tvm-sys/dynamic-linking"]
34+
static-linking = ["tvm-sys/static-linking"]
3535
blas = ["ndarray/blas"]
3636

3737
[dependencies]
3838
thiserror = "^1.0"
3939
ndarray = "0.12"
4040
num-traits = "0.2"
41-
tvm-macros = { version = "0.1", path = "../tvm-macros" }
41+
tvm-macros = { version = "0.1.1-alpha", path = "../tvm-macros" }
4242
paste = "0.1"
4343
mashup = "0.1"
4444
once_cell = "^1.3.1"
4545
memoffset = "0.5.6"
4646

4747
[dependencies.tvm-sys]
48-
version = "0.1"
48+
version = "0.1.1-alpha"
4949
default-features = false
5050
path = "../tvm-sys/"
5151

rust/tvm-rt/src/ndarray.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ impl NDArray {
287287
check_call!(ffi::TVMArrayCopyFromBytes(
288288
self.as_raw_dltensor(),
289289
data.as_ptr() as *mut _,
290-
data.len() * mem::size_of::<T>()
290+
(data.len() * mem::size_of::<T>()) as _,
291291
));
292292
}
293293

@@ -296,7 +296,7 @@ impl NDArray {
296296
check_call!(ffi::TVMArrayCopyToBytes(
297297
self.as_raw_dltensor(),
298298
data.as_ptr() as *mut _,
299-
self.size(),
299+
self.size() as _,
300300
));
301301
}
302302

rust/tvm-sys/Cargo.toml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,16 @@
1717

1818
[package]
1919
name = "tvm-sys"
20-
version = "0.1.0"
20+
version = "0.1.1-alpha"
2121
authors = ["TVM Contributors"]
2222
license = "Apache-2.0"
2323
edition = "2018"
24+
description = "Low level bindings to TVM's cross language API."
2425

2526
[features]
26-
default = []
27-
bindings = []
27+
default = ["dynamic-linking"]
28+
static-linking = []
29+
dynamic-linking = []
2830

2931
[dependencies]
3032
thiserror = "^1.0"
@@ -33,5 +35,6 @@ ndarray = "0.12"
3335
enumn = "^0.1"
3436

3537
[build-dependencies]
36-
bindgen = { version="0.51", default-features=false }
38+
bindgen = { version="0.57", default-features = false, features = ["runtime"] }
3739
anyhow = "^1.0"
40+
tvm-build = "0.1"

rust/tvm-sys/README.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
2+
<!--- or more contributor license agreements. See the NOTICE file -->
3+
<!--- distributed with this work for additional information -->
4+
<!--- regarding copyright ownership. The ASF licenses this file -->
5+
<!--- to you under the Apache License, Version 2.0 (the -->
6+
<!--- "License"); you may not use this file except in compliance -->
7+
<!--- with the License. You may obtain a copy of the License at -->
8+
9+
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
10+
11+
<!--- Unless required by applicable law or agreed to in writing, -->
12+
<!--- software distributed under the License is distributed on an -->
13+
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
14+
<!--- KIND, either express or implied. See the License for the -->
15+
<!--- specific language governing permissions and limitations -->
16+
<!--- under the License. -->
17+
18+
# tvm-sys
19+
20+
The low level bindings to TVM's C APIs for interacting with the runtime,
21+
the cross-language object system, and packed function API.
22+
23+
These will generate bindings to TVM, if you set `TVM_HOME` variable before
24+
building it will instruct the bindings to use your source tree, if not the
25+
crate will use `tvm-build` in order to build a sandboxed version of the library.
26+
27+
This feature is intended to simplify the installation for brand new TVM users
28+
by trying to automate the build process as much as possible.

rust/tvm-sys/build.rs

Lines changed: 92 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,65 +19,113 @@
1919

2020
extern crate bindgen;
2121

22-
use std::env;
23-
use std::path::PathBuf;
22+
use std::path::{Path, PathBuf};
2423

2524
use anyhow::{Context, Result};
25+
use tvm_build::BuildConfig;
26+
27+
/// The necessary information for detecting a TVM installation.
28+
struct TVMInstall {
29+
source_path: PathBuf,
30+
build_path: PathBuf,
31+
}
32+
33+
/// Find the TVM install using the provided path.
34+
fn find_using_tvm_path<P: AsRef<Path>>(tvm_path: P) -> Result<TVMInstall> {
35+
Ok(TVMInstall {
36+
source_path: tvm_path.as_ref().into(),
37+
build_path: tvm_path.as_ref().into(),
38+
})
39+
}
40+
41+
#[allow(unused)]
42+
fn if_unset<K: AsRef<std::ffi::OsStr>, V: AsRef<std::ffi::OsStr>>(k: K, v: V) -> Result<()> {
43+
match std::env::var(k.as_ref()) {
44+
Ok(other) if other != "" => {
45+
println!(
46+
"cargo:warning=Using existing environment variable setting {:?}={:?}",
47+
k.as_ref(),
48+
v.as_ref()
49+
);
50+
}
51+
_ => std::env::set_var(k, v),
52+
}
53+
54+
Ok(())
55+
}
56+
57+
/// Find a TVM installation using TVM build by either first installing or detecting.
58+
fn find_using_tvm_build() -> Result<TVMInstall> {
59+
let mut build_config = BuildConfig::default();
60+
build_config.repository = Some("https://github.com/apache/tvm".to_string());
61+
build_config.branch = Some(option_env!("TVM_BRANCH").unwrap_or("main").into());
62+
let build_result = tvm_build::build(build_config)?;
63+
let source_path = build_result.revision.source_path();
64+
let build_path = build_result.revision.build_path();
65+
Ok(TVMInstall {
66+
source_path,
67+
build_path,
68+
})
69+
}
2670

2771
fn main() -> Result<()> {
28-
let tvm_home = option_env!("TVM_HOME")
29-
.map::<Result<String>, _>(|s: &str| Ok(str::to_string(s)))
30-
.unwrap_or_else(|| {
31-
let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
32-
.canonicalize()
33-
.with_context(|| {
34-
format!(
35-
"failed to cannonicalize() CARGO_MANIFEST_DIR={}",
36-
env!("CARGO_MANIFEST_DIR")
37-
)
38-
})?;
39-
40-
Ok(crate_dir
41-
.parent()
42-
.with_context(|| {
43-
format!(
44-
"failed to find parent of CARGO_MANIFEST_DIR={}",
45-
env!("CARGO_MANIFEST_DIR")
46-
)
47-
})?
48-
.parent()
49-
.with_context(|| {
50-
format!(
51-
"failed to find the parent of the parent of CARGO MANIFEST_DIR={}",
52-
env!("CARGO_MANIFEST_DIR")
53-
)
54-
})?
55-
.to_str()
56-
.context("failed to convert to strings")?
57-
.to_string())
58-
})?;
59-
60-
if cfg!(feature = "bindings") {
61-
println!("cargo:rerun-if-env-changed=TVM_HOME");
72+
let TVMInstall {
73+
source_path,
74+
build_path,
75+
} = match option_env!("TVM_HOME") {
76+
Some(tvm_path) if tvm_path != "" => find_using_tvm_path(tvm_path),
77+
_ => find_using_tvm_build(),
78+
}?;
79+
80+
// If the TVM_HOME environment variable changed, the LLVM_CONFIG_PATH environment variable
81+
// changed, the build directory or headers have changed we need to rebuild the Rust bindings.
82+
println!("cargo:rerun-if-env-changed=TVM_HOME");
83+
println!("cargo:rerun-if-env-changed=LLVM_CONFIG_PATH");
84+
println!("cargo:rerun-if-changed={}", build_path.display());
85+
println!("cargo:rerun-if-changed={}/include", source_path.display());
86+
87+
if cfg!(feature = "static-linking") {
88+
println!("cargo:rustc-link-lib=static=tvm");
89+
// TODO(@jroesch): move this to tvm-build as library_path?
90+
println!(
91+
"cargo:rustc-link-search=native={}/build",
92+
build_path.display()
93+
);
94+
}
95+
96+
if cfg!(feature = "dynamic-linking") {
6297
println!("cargo:rustc-link-lib=dylib=tvm");
63-
println!("cargo:rustc-link-search=native={}/build", tvm_home);
98+
println!(
99+
"cargo:rustc-link-search=native={}/build",
100+
build_path.display()
101+
);
64102
}
65103

104+
let runtime_api = source_path.join("include/tvm/runtime/c_runtime_api.h");
105+
let backend_api = source_path.join("include/tvm/runtime/c_backend_api.h");
106+
let source_path = source_path.display().to_string();
107+
let dlpack_include = format!("-I{}/3rdparty/dlpack/include/", source_path);
108+
let tvm_include = format!("-I{}/include/", source_path);
109+
110+
let out_file = PathBuf::from(std::env::var("OUT_DIR")?).join("c_runtime_api.rs");
111+
66112
// @see rust-bindgen#550 for `blacklist_type`
67113
bindgen::Builder::default()
68-
.header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home))
69-
.header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home))
70-
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home))
71-
.clang_arg(format!("-I{}/include/", tvm_home))
114+
.header(runtime_api.display().to_string())
115+
.header(backend_api.display().to_string())
116+
.clang_arg(dlpack_include)
117+
.clang_arg(tvm_include)
72118
.blacklist_type("max_align_t")
73119
.layout_tests(false)
74120
.derive_partialeq(true)
75121
.derive_eq(true)
76122
.derive_default(true)
77123
.generate()
78-
.map_err(|()| anyhow::anyhow!("failed to generate bindings"))?
79-
.write_to_file(PathBuf::from("src/c_runtime_api.rs"))
80-
.context("failed to write bindings")?;
124+
.map_err(|()| {
125+
anyhow::anyhow!("bindgen failed to generate the Rust bindings for the C API")
126+
})?
127+
.write_to_file(out_file)
128+
.context("failed to write the generated Rust binding to disk")?;
81129

82130
Ok(())
83131
}

rust/tvm-sys/src/byte_array.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@ pub struct ByteArray {
4141
impl ByteArray {
4242
/// Gets the underlying byte-array
4343
pub fn data(&self) -> &'static [u8] {
44-
unsafe { std::slice::from_raw_parts(self.array.data as *const u8, self.array.size) }
44+
unsafe { std::slice::from_raw_parts(self.array.data as *const u8, self.array.size as _) }
4545
}
4646

4747
/// Gets the length of the underlying byte-array
4848
pub fn len(&self) -> usize {
49-
self.array.size
49+
self.array.size as _
5050
}
5151

5252
/// Converts the underlying byte-array to `Vec<u8>`
@@ -66,7 +66,7 @@ impl<T: AsRef<[u8]>> From<T> for ByteArray {
6666
ByteArray {
6767
array: TVMByteArray {
6868
data: arg.as_ptr() as *const c_char,
69-
size: arg.len(),
69+
size: arg.len() as _,
7070
},
7171
}
7272
}

0 commit comments

Comments
 (0)