Skip to content

Commit

Permalink
Added cuBLAS/Windows support
Browse files Browse the repository at this point in the history
  • Loading branch information
darxkies committed Jun 8, 2023
1 parent 8006e1d commit 4d12f70
Showing 1 changed file with 78 additions and 43 deletions.
121 changes: 78 additions & 43 deletions crates/ggml/sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,56 +107,91 @@ fn enable_clblast(build: &mut cc::Build) {

#[cfg(feature = "cublas")]
fn enable_cublas(build: &mut cc::Build) {
let targets_include = concat!(env!("CUDA_PATH"), "/targets/x86_64-linux/include");
let targets_lib = concat!(env!("CUDA_PATH"), "/targets/x86_64-linux/lib");

let out_dir = env::var("OUT_DIR").unwrap();

let object_file = format!("{}/ggml/src/ggml-cuda.o", &out_dir);
let object_file = format!(r"{}\ggml\src\ggml-cuda.o", &out_dir);

let path = std::path::Path::new(&object_file);
let parent_dir = path.parent().unwrap();

std::fs::create_dir_all(parent_dir).unwrap();

let parameters = [
"--forward-unknown-to-host-compiler",
"-O3",
"-std=c++11",
"-fPIC",
"-Iggml/include/ggml",
"-mtune=native",
"-pthread",
"-DGGML_USE_CUBLAS",
"-I/usr/local/cuda/include",
"-I/opt/cuda/include",
"-I",
targets_include,
"-c",
"ggml/src/ggml-cuda.cu",
"-o",
&object_file,
];

std::process::Command::new("nvcc")
.args(parameters)
.status()
.unwrap();

println!("cargo:rustc-link-search=native={}", targets_lib);
println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64");
println!("cargo:rustc-link-search=native=/opt/cuda/lib64");
println!("cargo:rustc-link-lib=cublas");
println!("cargo:rustc-link-lib=culibos");
println!("cargo:rustc-link-lib=cudart");
println!("cargo:rustc-link-lib=cublasLt");
println!("cargo:rustc-link-lib=dylib=stdc++");

build.object(object_file);
build.flag("-DGGML_USE_CUBLAS");
build.include("/usr/local/cuda/include");
build.include("/opt/cuda/include");
build.include(targets_include);
if cfg!(windows) {
let targets_include = concat!(env!("CUDA_PATH"), r"\include");
let targets_lib = concat!(env!("CUDA_PATH"), r"\lib\x64");

std::process::Command::new("nvcc")
.arg("-ccbin")
.arg(cc::Build::new().get_compiler().path().parent().unwrap().join("cl.exe"))
.arg("-I")
.arg(targets_include)
.arg("-o")
.arg(&object_file)
.arg("-x")
.arg("cu")
.arg("-maxrregcount=0")
.arg("--machine")
.arg("64")
.arg("--compile")
.arg("-cudart")
.arg("static")
.arg("-D_WINDOWS")
.arg("-DNDEBUG")
.arg("-DGGML_USE_CUBLAS")
.arg("-D_CRT_SECURE_NO_WARNINGS")
.arg("-D_MBCS")
.arg("-DWIN32")
.arg(r"-Iggml\include\ggml")
.arg(r"ggml\src\ggml-cuda.cu")
.status()
.unwrap();

println!("cargo:rustc-link-search=native={}", targets_lib);
println!("cargo:rustc-link-lib=cublas");
println!("cargo:rustc-link-lib=cudart");
println!("cargo:rustc-link-lib=cublasLt");

build.object(object_file);
build.flag("-DGGML_USE_CUBLAS");
build.include(targets_include);
} else {
let targets_include = concat!(env!("CUDA_PATH"), "/targets/x86_64-linux/include");
let targets_lib = concat!(env!("CUDA_PATH"), "/targets/x86_64-linux/lib");

std::process::Command::new("nvcc")
.arg("--forward-unknown-to-host-compiler")
.arg("-O3")
.arg("-std=c++11")
.arg("-fPIC")
.arg("-Iggml/include/ggml")
.arg("-mtune=native")
.arg("-pthread")
.arg("-DGGML_USE_CUBLAS")
.arg("-I/usr/local/cuda/include")
.arg("-I/opt/cuda/include")
.arg("-I")
.arg(targets_include)
.arg("-c")
.arg("ggml/src/ggml-cuda.cu")
.arg("-o")
.arg(&object_file)
.status()
.unwrap();

println!("cargo:rustc-link-search=native={}", targets_lib);
println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64");
println!("cargo:rustc-link-search=native=/opt/cuda/lib64");
println!("cargo:rustc-link-lib=cublas");
println!("cargo:rustc-link-lib=culibos");
println!("cargo:rustc-link-lib=cudart");
println!("cargo:rustc-link-lib=cublasLt");
println!("cargo:rustc-link-lib=dylib=stdc++");

build.object(object_file);
build.flag("-DGGML_USE_CUBLAS");
build.include("/usr/local/cuda/include");
build.include("/opt/cuda/include");
build.include(targets_include);
}
}

fn get_supported_target_features() -> std::collections::HashSet<String> {
Expand Down

0 comments on commit 4d12f70

Please sign in to comment.