diff --git a/examples/my-project/flake.nix b/examples/my-project/flake.nix index b6750ed..f3df42d 100644 --- a/examples/my-project/flake.nix +++ b/examples/my-project/flake.nix @@ -29,7 +29,7 @@ inputs.nix-ml-ops.flakeModules.kubernetesJob inputs.nix-ml-ops.flakeModules.kubernetesService inputs.nix-ml-ops.flakeModules.devcontainer - inputs.nix-ml-ops.flakeModules.linkNvidiaDrivers + inputs.nix-ml-ops.flakeModules.cuda inputs.nix-ml-ops.flakeModules.volumeMountNfs inputs.nix-ml-ops.flakeModules.devcontainerGcpCliTools inputs.nix-ml-ops.flakeModules.gkeCredential diff --git a/flake-modules/cuda.nix b/flake-modules/cuda.nix index b37c917..7b90a08 100644 --- a/flake-modules/cuda.nix +++ b/flake-modules/cuda.nix @@ -2,12 +2,10 @@ topLevel@{ flake-parts-lib, inputs, ... }: { imports = [ inputs.flake-parts.flakeModules.flakeModules ./common.nix - ./link-nvidia-drivers.nix ]; flake.flakeModules.cuda = { imports = [ topLevel.config.flake.flakeModules.common - topLevel.config.flake.flakeModules.linkNvidiaDrivers ]; options.perSystem = flake-parts-lib.mkPerSystemOption ({ lib, pkgs, system, ... }: { config = lib.mkIf (system != "aarch64-darwin") { @@ -15,11 +13,14 @@ topLevel@{ flake-parts-lib, inputs, ... }: { nixpkgs.config.cudaSupport = true; ml-ops.common = common: { - config.LD_LIBRARY_PATH = lib.mkMerge [ - "/run/opengl-driver/lib" - # bitsandbytes need to search for CUDA libraries - "${common.config.environmentVariables.CUDA_HOME}/lib" - ]; + config.devenvShellModule.enterShell = '' + export LD_LIBRARY_PATH="$(${ + lib.escapeShellArgs [ + "${inputs.nix-gl-host.defaultPackage.${system}}/bin/nixglhost" + "--print-ld-library-path" + ] + })":''${LD_LIBRARY_PATH:-} + ''; config.devenvShellModule.packages = [ common.config.cuda.home ]; diff --git a/flake-modules/link-nvidia-drivers.nix b/flake-modules/link-nvidia-drivers.nix deleted file mode 100644 index 23f66dd..0000000 --- a/flake-modules/link-nvidia-drivers.nix +++ /dev/null @@ -1,57 +0,0 @@ -topLevel@{ flake-parts-lib, inputs, ... }: { - imports = [ - ./common.nix - inputs.flake-parts.flakeModules.flakeModules - ]; - flake.flakeModules.linkNvidiaDrivers = { - imports = [ - topLevel.config.flake.flakeModules.common - ]; - options.perSystem = flake-parts-lib.mkPerSystemOption ({ lib, config, pkgs, options, ... }: { - ml-ops.common = lib.attrsets.optionalAttrs pkgs.stdenv.isLinux { - devenvShellModule.scripts.link-nvidia-drivers.exec = '' - # Create the symbolic links to drivers when running the container - # with `nvidia-docker --gpus=all` - mkdir -p /run/opengl-driver/lib - - ( - shopt -s nullglob - - # Note that nvidia-docker will mounts drivers to different - # directories according to the Linux distribution. For example, - # on Alpine Linux it's under /usr/lib64/ and on Debian it's - # under /usr/lib/x86_64-linux-gnu. - if \ - ! ln -sf -t /run/opengl-driver/lib \ - /usr/lib/x86_64-linux-gnu/libnvidia* \ - /usr/lib/x86_64-linux-gnu/libcuda* \ - /usr/lib64/libnvidia* \ - /usr/lib64/libcuda* \ - /usr/local/nvidia/lib64/libnvidia* \ - /usr/local/nvidia/lib64/libcuda* - then - >&2 echo "Cannot find Nvidia drivers when trying to create symbolic links to Nvidia drivers" - fi - ) - ''; - devenvShellModule.enterShell = '' - # Link drivers if the script is either known to be in a container, or - # it cannot detect if it is in a container - if - ( - ! command -v systemd-detect-virt > /dev/null && - >&2 echo "Trying to link Nvidia drivers to /run/opengl-driver/lib because we cannot detect if we are in a devcontainer" - ) || ( - # Clear LD_LIBRARY_PATH previously set by devenv because it could result in error in systemd-detect-virt - LD_LIBRARY_PATH= systemd-detect-virt --container > /dev/null && - >&2 echo "Trying to link Nvidia drivers to /run/opengl-driver/lib because we are in a devcontainer..." - ) - then - link-nvidia-drivers - fi - ''; - }; - - }); - }; -} diff --git a/flake.lock b/flake.lock index 269d280..daf9780 100644 --- a/flake.lock +++ b/flake.lock @@ -368,6 +368,26 @@ "type": "github" } }, + "nix-gl-host": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1722650779, + "narHash": "sha256-PgVOPj7dQcYBE3Ab0axqxDHUoMtaRxYWDKzHzHBn/s0=", + "owner": "Atry", + "repo": "nix-gl-host", + "rev": "3f99656f3ddc1f81b368628873f48b2dd1a4c8ad", + "type": "github" + }, + "original": { + "owner": "Atry", + "repo": "nix-gl-host", + "type": "github" + } + }, "nix-ld-rs": { "inputs": { "flake-compat": "flake-compat_2", @@ -732,6 +752,7 @@ "flake-parts": "flake-parts", "mach-nix": "mach-nix", "mk-shell-bin": "mk-shell-bin", + "nix-gl-host": "nix-gl-host", "nix-ld-rs": "nix-ld-rs", "nix2container": "nix2container", "nixago": "nixago", diff --git a/flake.nix b/flake.nix index 0be6414..b2f9661 100644 --- a/flake.nix +++ b/flake.nix @@ -50,6 +50,10 @@ inputs.nixpkgs.follows = "nixpkgs"; }; nix-ld-rs.url = "github:nix-community/nix-ld-rs"; + nix-gl-host = { + url = "github:Atry/nix-gl-host"; + inputs.nixpkgs.follows = "nixpkgs"; + }; }; outputs = inputs: let