Skip to content

Commit

Permalink
feat(build): default to host features
Browse files Browse the repository at this point in the history
  • Loading branch information
philpax committed Mar 16, 2023
1 parent 843f3ea commit 4121145
Showing 1 changed file with 65 additions and 23 deletions.
88 changes: 65 additions & 23 deletions ggml-raw/build.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::collections::HashSet;
use std::env;
use std::path::PathBuf;

fn main() {
// By default, this crate will attempt to compile ggml with the features of your host system if
// the host and target are the same. If they are not, it will turn off auto-feature-detection,
// and you will need to manually specify target features through target-features.

let ggml_src = ["ggml/ggml.c"];

let mut builder = cc::Build::new();
Expand All @@ -11,45 +14,35 @@ fn main() {

// This is a very basic heuristic for applying compile flags.
// Feel free to update this to fit your operating system.
let target_arch = std::env::var("CARGO_CFG_TARGET_ARCH").unwrap();
let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap();
let is_release = std::env::var("PROFILE").unwrap() == "release";

let supported_features: HashSet<_> = std::env::var("CARGO_CFG_TARGET_FEATURE")
.unwrap()
.split(',')
.map(|s| s.to_string())
.collect();
let target_arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap();
let target_os = env::var("CARGO_CFG_TARGET_OS").unwrap();
let is_release = env::var("PROFILE").unwrap() == "release";

match target_arch.as_str() {
"x86" | "x86_64" => {
let supports_fma = supported_features.contains("fma");
let supports_avx = supported_features.contains("avx");
let supports_avx2 = supported_features.contains("avx2");
let supports_f16c = supported_features.contains("f16c");
let supports_sse3 = supported_features.contains("sse3");
let features = x86::Features::get();

match target_os.as_str() {
"freebsd" | "haiku" | "ios" | "macos" | "linux" => {
build.flag("-pthread");

if supports_avx {
if features.avx {
build.flag("-mavx");
}
if supports_avx2 {
if features.avx2 {
build.flag("-mavx2");
}
if supports_fma {
if features.fma {
build.flag("-mfma");
}
if supports_f16c {
if features.f16c {
build.flag("-mf16c");
}
if supports_sse3 {
if features.sse3 {
build.flag("-msse3");
}
}
"windows" => match (supports_avx2, supports_avx) {
"windows" => match (features.avx2, features.avx) {
(true, _) => {
build.flag("/arch:AVX2");
}
Expand All @@ -68,8 +61,6 @@ fn main() {
}
build.compile("ggml");

println!("cargo:rerun-if-changed=ggml/ggml.h");

let bindings = bindgen::Builder::default()
.header("ggml/ggml.h")
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
Expand All @@ -86,3 +77,54 @@ fn main() {
.write_to_file(out_path.join("bindings.rs"))
.expect("Couldn't write bindings!");
}

fn get_supported_target_features() -> std::collections::HashSet<String> {
env::var("CARGO_CFG_TARGET_FEATURE")
.unwrap()
.split(',')
.map(|s| s.to_string())
.collect()
}

mod x86 {
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Features {
pub fma: bool,
pub avx: bool,
pub avx2: bool,
pub f16c: bool,
pub sse3: bool,
}
impl Features {
pub fn get() -> Self {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
if std::env::var("HOST") == std::env::var("TARGET") {
return Self::get_host();
}

Self::get_target()
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub fn get_host() -> Self {
Self {
fma: std::is_x86_feature_detected!("fma"),
avx: std::is_x86_feature_detected!("avx"),
avx2: std::is_x86_feature_detected!("avx2"),
f16c: std::is_x86_feature_detected!("f16c"),
sse3: std::is_x86_feature_detected!("sse3"),
}
}

pub fn get_target() -> Self {
let features = crate::get_supported_target_features();
Self {
fma: features.contains("fma"),
avx: features.contains("avx"),
avx2: features.contains("avx2"),
f16c: features.contains("f16c"),
sse3: features.contains("sse3"),
}
}
}
}

0 comments on commit 4121145

Please sign in to comment.