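"""Validate that the project's Python requirements are satisfied.

Detects whether an nVidia or AMD GPU toolkit is available, reports the
Torch backend and detected GPUs, and installs or upgrades the packages
listed in a given requirements file.
"""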
import os
import sys
import shutil
import argparse
import subprocess
from setup_windows import check_repo_version

# Get the absolute path of the current file's directory (the kohya_ss project directory)
project_directory = os.path.dirname(os.path.abspath(__file__))

# If this file lives in a "tools" subdirectory, move one level up to the project root
if os.path.basename(project_directory) == "tools":
    project_directory = os.path.dirname(project_directory)

# Add the project directory to the beginning of the Python search path
sys.path.insert(0, project_directory)

from library.custom_logging import setup_logging

# Set up logging
log = setup_logging()

def check_torch():
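    """Log the detected GPU toolkit, Torch backend, and GPUs.

    Returns the major version number of the installed Torch package,
    or exits the process if Torch cannot be imported.
    """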
    # Check for an nVidia toolkit or an AMD toolkit
    if shutil.which('nvidia-smi') is not None or os.path.exists(
        os.path.join(
            os.environ.get('SystemRoot') or r'C:\Windows',
            'System32',
            'nvidia-smi.exe',
        )
    ):
        log.info('nVidia toolkit detected')
    elif shutil.which('rocminfo') is not None or os.path.exists(
        '/opt/rocm/bin/rocminfo'
    ):
        log.info('AMD toolkit detected')
    else:
        log.info('Using CPU-only Torch')

    try:
        import torch

        log.info(f'Torch {torch.__version__}')

        # Check if CUDA is available
        if not torch.cuda.is_available():
            log.warning('Torch reports CUDA not available')
        else:
            if torch.version.cuda:
                # Log nVidia CUDA and cuDNN versions
                log.info(
                    f'Torch backend: nVidia CUDA {torch.version.cuda} '
                    f'cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
                )
            elif torch.version.hip:
                # Log AMD ROCm HIP version
                log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
            else:
                log.warning('Unknown Torch backend')

            # Log information about each detected GPU
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                log.info(
                    f'Torch detected GPU: {torch.cuda.get_device_name(i)} '
                    f'VRAM {round(props.total_memory / 1024 / 1024)}MB '
                    f'Arch {torch.cuda.get_device_capability(i)} '
                    f'Cores {props.multi_processor_count}'
                )
        # Return the major Torch version; splitting on '.' also handles multi-digit majors
        return int(torch.__version__.split('.')[0])
    except Exception as e:
        log.error(f'Could not load torch: {e}')
        sys.exit(1)


def install_requirements(requirements_file):
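    """Install or upgrade the packages listed in the given requirements file via pip."""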
    log.info('Verifying requirements')
    # Run pip through the current interpreter; passing a list of arguments
    # avoids shell-quoting issues with paths that contain spaces
    subprocess.run(
        [sys.executable, '-m', 'pip', 'install', '-U', '-r', requirements_file],
        check=False,
        env=os.environ,
    )


def main():
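    """Check the repo version, parse command line arguments, and install requirements."""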
    check_repo_version()
    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description='Validate that requirements are satisfied.'
    )
    parser.add_argument(
        '-r',
        '--requirements',
        type=str,
        default='requirements.txt',  # assumed default so pip is not handed None when the flag is omitted
        help='Path to the requirements file.',
    )
    parser.add_argument('--debug', action='store_true', help='Debug on')  # parsed but currently unused
    args = parser.parse_args()

    install_requirements(args.requirements)


if __name__ == '__main__':
    main()
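
# Example invocation (the filename and path are assumed for illustration; the
# script checks for a "tools" parent directory, so it is expected to live there):
#
#   python tools/validate_requirements.py -r requirements.txt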