Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for CONDA_OVERRIDE_CUDA #818

Merged
merged 11 commits into from
Aug 20, 2024
116 changes: 114 additions & 2 deletions crates/rattler_virtual_packages/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,51 @@
pub mod libc;
pub mod linux;
pub mod osx;

Check warning on line 35 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format, Lint and Test the Python bindings

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs

Check warning on line 35 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs
use archspec::cpu::Microarchitecture;
use once_cell::sync::OnceCell;
use rattler_conda_types::{GenericVirtualPackage, PackageName, Platform, Version};
use rattler_conda_types::{GenericVirtualPackage, PackageName, ParseVersionError, Platform, Version};
use std::env;
use std::hash::{Hash, Hasher};
use std::str::FromStr;
use std::sync::Arc;

use crate::osx::ParseOsxVersionError;
use libc::DetectLibCError;
use linux::ParseLinuxVersionError;
use serde::{Deserialize, Deserializer, Serialize, Serializer};

/// Traits for overridable virtual packages
/// Use as `Cuda::from_default_env_var.unwrap_or(Cuda::current().into()).unwrap()`

Check warning on line 50 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format, Lint and Test the Python bindings

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs

Check warning on line 50 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs
pub trait EnvOverride: Sized {
/// Parse `env_var_value`
fn from_env_var_name_with_var(env_var_name: &str, env_var_value: &str) -> Result<Self, ParseVersionError>;

/// Read the environment variable and if it exists, try to parse it with [`EnvOverride::from_env_var_name_with_var`]
/// If the output is:
/// - `None`, then the environment variable did not exist,
/// - `Some(Err(None))`, then the environment variable exist but was set to zero, so the package should be disabled
/// - `Some(Ok(pkg))`, then the override was for the package.
fn from_env_var_name(env_var_name: &str) -> Option<Result<Self, Option<ParseVersionError>>> {
let var = env::var(env_var_name).ok()?;
if var.len() == 0 {
Some(Err(None))
} else {
Some(Self::from_env_var_name_with_var(env_var_name, &var).map_err(Some))
}
}

/// Default name of the environment variable that overrides the virtual package.
const DEFAULT_ENV_NAME: &'static str;

/// Shortcut for `EnvOverride::from_env_var_name(EnvOverride::DEFAULT_ENV_NAME)`.
fn from_default_env_var() -> Option<Result<Self, Option<ParseVersionError>>> {
Self::from_env_var_name(Self::DEFAULT_ENV_NAME)
}

Check warning on line 75 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format, Lint and Test the Python bindings

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs

Check warning on line 75 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs
}



/// An enum that represents all virtual package types provided by this library.
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
pub enum VirtualPackage {
Expand Down Expand Up @@ -188,6 +221,12 @@
}
}

impl From<Version> for Linux {
fn from(version: Version) -> Self {
Linux { version }
}
}

/// `LibC` virtual package description
#[derive(Clone, Eq, PartialEq, Hash, Debug, Deserialize)]
pub struct LibC {
Expand Down Expand Up @@ -228,7 +267,15 @@
VirtualPackage::LibC(libc)
}
}

Check warning on line 270 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format, Lint and Test the Python bindings

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs

Check warning on line 270 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs
impl EnvOverride for LibC {
const DEFAULT_ENV_NAME: &'static str = "CONDA_OVERRIDE_GLIBC";

fn from_env_var_name_with_var(_env_var_name: &str, env_var_value: &str) -> Result<Self, ParseVersionError> {
Version::from_str(env_var_value).map(|version| Self{family: "glibc".into(), version})
}
}

/// Cuda virtual package description
#[derive(Clone, Eq, PartialEq, Hash, Debug, Deserialize)]
pub struct Cuda {
Expand All @@ -243,6 +290,21 @@
}
}

impl From<Version> for Cuda {
fn from(version: Version) -> Self {
Self { version }
}
}

Check warning on line 297 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format, Lint and Test the Python bindings

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs

Check warning on line 297 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs

impl EnvOverride for Cuda {
fn from_env_var_name_with_var(_env_var_name: &str, env_var_value: &str) -> Result<Self, ParseVersionError> {
Version::from_str(env_var_value).map(|version| Self{version})

}

const DEFAULT_ENV_NAME: &'static str = "CONDA_OVERRIDE_CUDA";
}

impl From<Cuda> for GenericVirtualPackage {
fn from(cuda: Cuda) -> Self {
GenericVirtualPackage {
Expand Down Expand Up @@ -359,7 +421,7 @@
GenericVirtualPackage {
name: PackageName::new_unchecked("__archspec"),
version: Version::major(1),
build_string: archspec.spec.name().to_string(),
build_string: archspec.spec.name().into(),
}
}
}
Expand Down Expand Up @@ -403,13 +465,63 @@
}
}

impl From<Version> for Osx {
fn from(version: Version) -> Self {
Self { version }
}
}

Check warning on line 472 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format, Lint and Test the Python bindings

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs

Check warning on line 472 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs

impl EnvOverride for Osx {
fn from_env_var_name_with_var(_env_var_name: &str, env_var_value: &str) -> Result<Self, ParseVersionError> {
Version::from_str(env_var_value).map(|version| Self{version})
}

const DEFAULT_ENV_NAME: &'static str = "CONDA_OVERRIDE_OSX";
}

#[cfg(test)]
mod test {
use std::env;
use std::str::FromStr;

Check warning on line 486 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format, Lint and Test the Python bindings

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs

Check warning on line 486 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs
use rattler_conda_types::Version;

use crate::EnvOverride;
use crate::VirtualPackage;
use crate::LibC;
use crate::Cuda;
use crate::Osx;

#[test]
fn doesnt_crash() {
let virtual_packages = VirtualPackage::current().unwrap();
println!("{virtual_packages:?}");
}
#[test]

Check warning on line 500 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format, Lint and Test the Python bindings

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs

Check warning on line 500 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs
fn parse_libc() {
let v = "1.23";
let res = LibC{version: Version::from_str(v).unwrap(), family: "glibc".into()};
env::set_var(LibC::DEFAULT_ENV_NAME, v);
assert_eq!(LibC::from_default_env_var(), Some(Ok(res)));
env::set_var(LibC::DEFAULT_ENV_NAME, "");
assert_eq!(LibC::from_default_env_var(), Some(Err(None)));
env::remove_var(LibC::DEFAULT_ENV_NAME);
assert_eq!(LibC::from_default_env_var(), None);
}

#[test]

Check warning on line 512 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format, Lint and Test the Python bindings

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs

Check warning on line 512 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs
fn parse_cuda() {
let v = "1.234";
let res = Cuda{version: Version::from_str(v).unwrap()};
env::set_var(Cuda::DEFAULT_ENV_NAME, v);
assert_eq!(Cuda::from_default_env_var(), Some(Ok(res)));
}

#[test]

Check warning on line 520 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format, Lint and Test the Python bindings

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs

Check warning on line 520 in crates/rattler_virtual_packages/src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/rattler/rattler/crates/rattler_virtual_packages/src/lib.rs
fn parse_osx() {
let v = "2.345";
let res = Osx{version: Version::from_str(v).unwrap()};
env::set_var(Osx::DEFAULT_ENV_NAME, v);
assert_eq!(Osx::from_default_env_var(), Some(Ok(res)));
}
}
Loading