-
Notifications
You must be signed in to change notification settings - Fork 0
/
build.rs
65 lines (55 loc) · 2.13 KB
/
build.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
extern crate cc;
/// Build-script entry point used when the `cuda` feature is enabled.
///
/// Steps performed:
/// 1. Resolve the CUDA toolkit root from the `CUDA_PATH` environment variable,
///    falling back to `/usr/local/cuda` when it is unset or not valid Unicode.
/// 2. Tell the linker to search `<root>/lib64` and to link against `cudart`.
/// 3. Compile every `*.cu` file found directly inside the `cuda` directory into
///    its own static library named after the file stem.
/// 4. Register the rebuild triggers (`build.rs`, the `cuda` directory and the
///    `CUDA_PATH` variable) and declare the `tarpaulin_include` cfg.
///
/// # Panics
///
/// Panics if a `.cu` file name is not valid UTF-8. `cc::Build::compile` also
/// aborts the build script when the CUDA compiler fails.
#[cfg(feature = "cuda")]
fn main() {
    use std::env;
    use std::fs;
    use std::path::PathBuf;

    // `-gencode` targets handed to the CUDA compiler for every kernel file.
    // Each entry is passed as a separate `-gencode <spec>` pair.
    const GENCODE_FLAGS: [&str; 6] = [
        "arch=compute_50,code=sm_50",
        "arch=compute_60,code=sm_60",
        "arch=compute_61,code=sm_61",
        "arch=compute_70,code=sm_70",
        "arch=compute_75,code=sm_75",
        "arch=compute_80,code=sm_80",
    ];

    // Check for the CUDA toolkit installation path.
    let cuda_path = env::var("CUDA_PATH")
        .map(PathBuf::from)
        .unwrap_or_else(|_| PathBuf::from("/usr/local/cuda")); // fallback path for Linux

    // Specify the CUDA library path to the linker.
    println!(
        "cargo:rustc-link-search=native={}",
        cuda_path.join("lib64").display()
    );

    // Link against the cudart library.
    println!("cargo:rustc-link-lib=cudart");

    let cuda_dir = "cuda"; // Replace with your actual directory if different

    // Compile all CUDA files in the directory.
    match fs::read_dir(cuda_dir) {
        Ok(entries) => {
            // Entries that cannot be read are skipped, as before.
            for entry in entries.filter_map(Result::ok) {
                let path = entry.path();
                if path.extension().and_then(|ext| ext.to_str()) != Some("cu") {
                    continue;
                }

                let stem = path
                    .file_stem()
                    .and_then(|stem| stem.to_str())
                    .expect("CUDA source file name must be valid UTF-8");

                let mut build = cc::Build::new();
                build.cuda(true);
                build.flag("-cudart=shared");
                // Add each gencode specification individually.
                for flag in &GENCODE_FLAGS {
                    build.flag("-gencode").flag(*flag);
                }
                build
                    .file(&path)
                    .compile(format!("lib{}.a", stem).as_str());
            }
        }
        Err(err) => {
            // Cargo hides a build script's stderr unless the script fails, so
            // surface the problem as a warning the user actually sees.
            println!(
                "cargo:warning=Could not read directory {}: {}",
                cuda_dir, err
            );
        }
    }

    // The single-colon `cargo:` form is used for every directive: Cargo
    // releases before 1.77 do not recognise the `cargo::` prefix.
    println!("cargo:rerun-if-changed=build.rs");
    println!("cargo:rerun-if-changed={}", cuda_dir);
    // The toolkit location comes from the environment, so a change to it must
    // re-run this script as well.
    println!("cargo:rerun-if-env-changed=CUDA_PATH");
    println!("cargo:rustc-check-cfg=cfg(tarpaulin_include)");
}
/// Build-script entry point used when the `cuda` feature is disabled.
///
/// Nothing needs to be compiled or linked in this configuration, so the
/// function body is intentionally empty.
#[cfg(not(feature = "cuda"))]
fn main() {}