Rework CUDA/native-library setup and diagnostics #1041

Merged 3 commits on Mar 13, 2024
Changes from all commits
9 changes: 2 additions & 7 deletions bitsandbytes/__init__.py
@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import cuda_setup, research, utils
from . import research, utils
from .autograd._functions import (
MatmulLtState,
bmm_cublas,
@@ -12,11 +12,8 @@
matmul_cublas,
mm_cublas,
)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules

if COMPILED_WITH_CUDA:
from .optim import adam
from .optim import adam

__pdoc__ = {
"libbitsandbytes": False,
@@ -25,5 +22,3 @@
}

__version__ = "0.44.0.dev"

PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes"
108 changes: 2 additions & 106 deletions bitsandbytes/__main__.py
@@ -1,108 +1,4 @@
import glob
import os
import sys
from warnings import warn

import torch

HEADER_WIDTH = 60


def find_dynamic_library(folder, filename):
for ext in ("so", "dll", "dylib"):
yield from glob.glob(os.path.join(folder, "**", filename + ext))


def generate_bug_report_information():
print_header("")
print_header("BUG REPORT INFORMATION")
print_header("")
print('')

path_sources = [
("ANACONDA CUDA PATHS", os.environ.get("CONDA_PREFIX")),
("/usr/local CUDA PATHS", "/usr/local"),
("CUDA PATHS", os.environ.get("CUDA_PATH")),
("WORKING DIRECTORY CUDA PATHS", os.getcwd()),
]
try:
ld_library_path = os.environ.get("LD_LIBRARY_PATH")
if ld_library_path:
for path in set(ld_library_path.strip().split(os.pathsep)):
path_sources.append((f"LD_LIBRARY_PATH {path} CUDA PATHS", path))
except Exception as e:
print(f"Could not parse LD_LIBRARY_PATH: {e}")

for name, path in path_sources:
if path and os.path.isdir(path):
print_header(name)
print(list(find_dynamic_library(path, '*cuda*')))
print("")


def print_header(
txt: str, width: int = HEADER_WIDTH, filler: str = "+"
) -> None:
txt = f" {txt} " if txt else ""
print(txt.center(width, filler))


def print_debug_info() -> None:
from . import PACKAGE_GITHUB_URL
print(
"\nAbove we output some debug information. Please provide this info when "
f"creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ...\n"
)


def main():
generate_bug_report_information()

from . import COMPILED_WITH_CUDA
from .cuda_setup.main import get_compute_capabilities

print_header("OTHER")
print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}")
print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}")
print_header("")
print_header("DEBUG INFO END")
print_header("")
print("Checking that the library is importable and CUDA is callable...")
print("\nWARNING: Please be sure to sanitize sensitive info from any such env vars!\n")

try:
from bitsandbytes.optim import Adam

p = torch.nn.Parameter(torch.rand(10, 10).cuda())
a = torch.rand(10, 10).cuda()

p1 = p.data.sum().item()

adam = Adam([p])

out = a * p
loss = out.sum()
loss.backward()
adam.step()

p2 = p.data.sum().item()

assert p1 != p2
print("SUCCESS!")
print("Installation was successful!")
except ImportError:
print()
warn(
f"WARNING: {__package__} is currently running as CPU-only!\n"
"Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
f"If you think that this is so erroneously,\nplease report an issue!"
)
print_debug_info()
except Exception as e:
print(e)
print_debug_info()
sys.exit(1)


if __name__ == "__main__":
from bitsandbytes.diagnostics.main import main

main()
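The deleted main() above doubled as an installation smoke test: run one optimizer step and assert that the parameter actually changed. A minimal standalone sketch of the same check, assuming a CUDA device and a GPU-enabled bitsandbytes build:

import torch
from bitsandbytes.optim import Adam

# One training step with bitsandbytes' 8-bit Adam; if the native CUDA
# library loaded correctly, the parameter values must change.
p = torch.nn.Parameter(torch.rand(10, 10).cuda())
a = torch.rand(10, 10).cuda()

before = p.data.sum().item()
adam = Adam([p])
(a * p).sum().backward()
adam.step()

assert before != p.data.sum().item(), "optimizer step did not update the parameter"
print("SUCCESS!")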
149 changes: 117 additions & 32 deletions bitsandbytes/cextension.py
@@ -1,39 +1,124 @@
"""
extract factors the build is dependent on:
[X] compute capability
[ ] TODO: Q - What if we have multiple GPUs of different makes?
- CUDA version
- Software:
- CPU-only: only CPU quantization functions (no optimizer, no matrix multiplication)
- CuBLAS-LT: full-build 8-bit optimizer
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)

evaluation:
- if paths faulty, return meaningful error
- else:
- determine CUDA version
- determine capabilities
- based on that set the default path
"""

import ctypes as ct
from warnings import warn
import logging
import os
from pathlib import Path

import torch

from bitsandbytes.cuda_setup.main import CUDASetup
from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs

logger = logging.getLogger(__name__)


def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
"""
Get the disk path to the CUDA BNB native library specified by the
given CUDA specs, taking into account the `BNB_CUDA_VERSION` override environment variable.

The library is not guaranteed to exist at the returned path.
"""
library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}"
if not cuda_specs.has_cublaslt:
# if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
library_name += "_nocublaslt"
library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}"

override_value = os.environ.get("BNB_CUDA_VERSION")
if override_value:
library_name_stem, _, library_name_ext = library_name.rpartition(".")
# `library_name_stem` will now be e.g. `libbitsandbytes_cuda118`;
# let's remove any trailing numbers:
library_name_stem = library_name_stem.rstrip("0123456789")
# `library_name_stem` will now be e.g. `libbitsandbytes_cuda`;
# let's tack the new version number and the original extension back on.
library_name = f"{library_name_stem}{override_value}.{library_name_ext}"
logger.warning(
f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n"
"This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n"
"If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n"
"If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n"
"For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64\n"
)

return PACKAGE_DIR / library_name
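To make the override concrete, here is a small sketch of the filename rewrite performed above (the version numbers are illustrative, not taken from the PR):

# Suppose PyTorch was built against CUDA 12.1 but BNB_CUDA_VERSION=118.
library_name = "libbitsandbytes_cuda121.so"
stem, _, ext = library_name.rpartition(".")  # -> "libbitsandbytes_cuda121", "so"
stem = stem.rstrip("0123456789")             # -> "libbitsandbytes_cuda"
print(f"{stem}118.{ext}")                    # -> libbitsandbytes_cuda118.so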


[Collaborator comment] this refactor here, I also find very useful, much more pythonic and maintainable. thanks!
[Collaborator comment] of course there's still a lot of work, but this is a significant step forward

class BNBNativeLibrary:
_lib: ct.CDLL
compiled_with_cuda = False

def __init__(self, lib: ct.CDLL):
self._lib = lib

def __getattr__(self, item):
return getattr(self._lib, item)


class CudaBNBNativeLibrary(BNBNativeLibrary):
compiled_with_cuda = True

def __init__(self, lib: ct.CDLL):
super().__init__(lib)
lib.get_context.restype = ct.c_void_p
lib.get_cusparse.restype = ct.c_void_p
lib.cget_managed_ptr.restype = ct.c_void_p
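Setting restype explicitly matters because ctypes assumes every foreign function returns a C int, which would truncate a 64-bit pointer such as the handles returned above. A minimal sketch of the same pattern, using libc's malloc on a POSIX system as a stand-in for the bitsandbytes symbols:

import ctypes as ct

# ctypes defaults a foreign function's return type to c_int, so a 64-bit
# pointer would be truncated; declaring c_void_p preserves the full value.
libc = ct.CDLL(None)  # handle to the running process's own symbols (POSIX only)
libc.malloc.restype = ct.c_void_p
libc.free.argtypes = [ct.c_void_p]
ptr = libc.malloc(16)
libc.free(ptr)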


def get_native_library() -> BNBNativeLibrary:
binary_path = PACKAGE_DIR / f"libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}"
cuda_specs = get_cuda_specs()
if cuda_specs:
cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)
if cuda_binary_path.exists():
binary_path = cuda_binary_path
else:
logger.warning("Could not find the bitsandbytes CUDA binary at %r", cuda_binary_path)
logger.debug(f"Loading bitsandbytes native library from: {binary_path}")
dll = ct.cdll.LoadLibrary(str(binary_path))

if hasattr(dll, "get_context"): # only a CUDA-built library exposes this
return CudaBNBNativeLibrary(dll)

logger.warning(
"The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable."
)
return BNBNativeLibrary(dll)

setup = CUDASetup.get_instance()
if setup.initialized != True:
setup.run_cuda_setup()

lib = setup.lib
try:
if lib is None and torch.cuda.is_available():
CUDASetup.get_instance().generate_instructions()
CUDASetup.get_instance().print_log_stack()
raise RuntimeError('''
CUDA Setup failed despite GPU being available. Please run the following command to get more information:

python -m bitsandbytes

Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''')
_ = lib.cadam32bit_grad_fp32 # raises an error if the library could not be found -> COMPILED_WITH_CUDA=False
lib.get_context.restype = ct.c_void_p
lib.get_cusparse.restype = ct.c_void_p
lib.cget_managed_ptr.restype = ct.c_void_p
COMPILED_WITH_CUDA = True
except AttributeError as ex:
warn("The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.")
COMPILED_WITH_CUDA = False
print(str(ex))


# print the setup details after checking for errors so we do not print twice
#if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
#setup.print_log_stack()
lib = get_native_library()
except Exception as e:
lib = None
logger.error(f"Could not load bitsandbytes native library: {e}", exc_info=True)
if torch.cuda.is_available():
logger.warning(
"""
CUDA Setup failed despite CUDA being available. Please run the following command to get more information:

python -m bitsandbytes

Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues
"""
)
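With the class hierarchy above, callers get a single feature flag instead of a module-level constant. A usage sketch, assuming the module loads successfully:

from bitsandbytes.cextension import get_native_library

# compiled_with_cuda is a plain class attribute set by whichever
# BNBNativeLibrary subclass get_native_library() returned.
lib = get_native_library()
if lib.compiled_with_cuda:
    print("CUDA build loaded; 8-bit optimizers are available")
else:
    print("CPU-only build loaded")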
12 changes: 12 additions & 0 deletions bitsandbytes/consts.py
@@ -0,0 +1,12 @@
from pathlib import Path
import platform

DYNAMIC_LIBRARY_SUFFIX = {
"Darwin": ".dylib",
"Linux": ".so",
"Windows": ".dll",
}.get(platform.system(), ".so")

PACKAGE_DIR = Path(__file__).parent
PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes"
NONPYTORCH_DOC_URL = "https://github.com/TimDettmers/bitsandbytes/blob/main/docs/source/nonpytorchcuda.mdx"
[Collaborator comment on lines +1 to +12] much cleaner like this, nice! also especially like the .get(platform.system(), ".so")
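A quick illustration of that lookup, assuming a Linux host:

import platform
from pathlib import Path

# Same pattern as consts.py: map the OS name to its shared-library
# suffix, defaulting to ".so" on unrecognized platforms.
suffix = {"Darwin": ".dylib", "Linux": ".so", "Windows": ".dll"}.get(platform.system(), ".so")
print(Path.cwd() / f"libbitsandbytes_cpu{suffix}")  # -> .../libbitsandbytes_cpu.so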

53 changes: 0 additions & 53 deletions bitsandbytes/cuda_setup/env_vars.py

This file was deleted.
