Skip to content

Commit

Permalink
nix : add cuda, use a symlinked toolkit for cmake (ggerganov#3202)
Browse files Browse the repository at this point in the history
  • Loading branch information
Green-Sky authored Sep 25, 2023
1 parent c091cdf commit a98b163
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@
);
pkgs = import nixpkgs { inherit system; };
nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ];
cudatoolkit_joined = with pkgs; symlinkJoin {
# HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit
# see https://github.com/NixOS/nixpkgs/issues/224291
# copied from jaxlib
name = "${cudaPackages.cudatoolkit.name}-merged";
paths = [
cudaPackages.cudatoolkit.lib
cudaPackages.cudatoolkit.out
] ++ lib.optionals (lib.versionOlder cudaPackages.cudatoolkit.version "11") [
# for some reason some of the required libs are in the targets/x86_64-linux
# directory; not sure why but this works around it
"${cudaPackages.cudatoolkit}/targets/${system}"
];
};
llama-python =
pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
postPatch = ''
Expand Down Expand Up @@ -70,6 +84,13 @@
"-DLLAMA_CLBLAST=ON"
];
};
packages.cuda = pkgs.stdenv.mkDerivation {
inherit name src meta postPatch nativeBuildInputs postInstall;
buildInputs = with pkgs; buildInputs ++ [ cudatoolkit_joined ];
cmakeFlags = cmakeFlags ++ [
"-DLLAMA_CUBLAS=ON"
];
};
packages.rocm = pkgs.stdenv.mkDerivation {
inherit name src meta postPatch nativeBuildInputs postInstall;
buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ];
Expand Down

0 comments on commit a98b163

Please sign in to comment.