Skip to content

Commit

Permalink
Support pip installation (#25)
Browse files Browse the repository at this point in the history
* support install from prebuilt
* fix invalid cross device link problem
* change version backto v1.0.0
* add a env variable to determine whether to use the local version
* flux is already registerd on pypi, renamed to byte_flux
* rename the package to the byte_flux
* read version from __init__.py rather than version.txt
* v1.0.1
* use local version by default
* use public version by default
  • Loading branch information
zheng-ningxin authored Jul 25, 2024
1 parent 532ad84 commit a65ee72
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 149 deletions.
117 changes: 0 additions & 117 deletions gen_version.py

This file was deleted.

18 changes: 0 additions & 18 deletions pyproject.toml

This file was deleted.

3 changes: 2 additions & 1 deletion python/flux/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
#
################################################################################
__version__ = "1.0.2"

from .cpp_mod import *
from .ag_gemm import *
Expand All @@ -24,4 +25,4 @@
from .gemm_rs_sm80 import *
from .util import *
from .dist_utils import *
from .version import __version__ as __version__

134 changes: 122 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,81 @@
import glob
import os
import shutil
import sys
import re
import ast
from pathlib import Path

import urllib
import urllib.request
import urllib.error
import setuptools
import torch
import subprocess
from torch.utils.cpp_extension import BuildExtension

from gen_version import generate_versoin_file, check_final_release
from packaging.version import parse, Version
from typing import Optional, Tuple
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

# Project directory root
root_path: Path = Path(__file__).resolve().parent

version_txt = os.path.join(root_path, "version.txt")
version_file = os.path.join(root_path, "python/flux/version.py")
is_dev = not check_final_release()
flux_version = generate_versoin_file(version_txt, version_file, dev=is_dev)
enable_nvshmem = int(os.getenv("FLUX_SHM_USE_NVSHMEM", 0))
PACKAGE_NAME = "byte_flux"
BASE_WHEEL_URL = "https://github.com/bytedance/flux/releases/download/{tag_name}/{wheel_name}"
FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
USE_LOCAL_VERSION = int(os.getenv("FLUX_USE_LOCAL_VERSION", 0))

def cuda_version() -> Tuple[int, ...]:
"""CUDA Toolkit version as a (major, minor) by nvcc --version"""

# Try finding NVCC
nvcc_bin: Optional[Path] = None
if nvcc_bin is None and os.getenv("CUDA_HOME"):
# Check in CUDA_HOME
cuda_home = Path(os.getenv("CUDA_HOME"))
nvcc_bin = cuda_home / "bin" / "nvcc"
if nvcc_bin is None:
# Check if nvcc is in path
nvcc_bin = shutil.which("nvcc")
if nvcc_bin is not None:
nvcc_bin = Path(nvcc_bin)
if nvcc_bin is None:
# Last-ditch guess in /usr/local/cuda
cuda_home = Path("/usr/local/cuda")
nvcc_bin = cuda_home / "bin" / "nvcc"
if not nvcc_bin.is_file():
raise FileNotFoundError(f"Could not find NVCC at {nvcc_bin}")

# Query NVCC for version info
output = subprocess.run(
[nvcc_bin, "-V"],
capture_output=True,
check=True,
universal_newlines=True,
)
match = re.search(r"release\s*([\d.]+)", output.stdout)
version = match.group(1).split(".")
return tuple(int(v) for v in version)


def get_local_version(public_version):
cuda_version_major, cuda_version_minor = cuda_version()
torch_version_splits = torch.__version__.split(".")
torch_version = f"{torch_version_splits[0]}.{torch_version_splits[1]}"
version = public_version + f"+cu{cuda_version_major}{cuda_version_minor}" + f"torch{torch_version}"
return version

def get_public_version():
with open(Path(root_path) / "python" / "flux" / "__init__.py", "r") as f:
version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
public_version = ast.literal_eval(version_match.group(1))
return public_version

def get_package_version():
global USE_LOCAL_VERSION
public_version = get_public_version()
if USE_LOCAL_VERSION:
return get_local_version(public_version)
return public_version

def pathlib_wrapper(func):
def wrapper(*kargs, **kwargs):
Expand All @@ -43,6 +102,9 @@ def cutlass_deps():
def read_flux_ths_files():
file_path = root_path / "build/src/ths_op/flux_ths_files.txt"
variables = {}
if not os.path.exists(file_path):
# flux is installed through pip3, the flux_ths_files.txt is not generated
return []
with open(file_path, "r") as file:
for line in file:
if "=" in line:
Expand Down Expand Up @@ -125,8 +187,55 @@ def setup_pytorch_extension() -> setuptools.Extension:
)


def get_wheel_url():
flux_tag_version = get_public_version()
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
torch_version_raw = parse(torch.__version__)
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
torch_cuda_version = parse(torch.version.cuda)
torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
wheel_filename = f"{PACKAGE_NAME}-{flux_tag_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-linux_x86_64.whl"
wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flux_tag_version}", wheel_name=wheel_filename)
return wheel_url, wheel_filename


class CachedWheelsCommand(_bdist_wheel):
"""
The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
find an existing wheel (which is currently the case for all flash attention installs). We use
the environment parameters to detect whether there is already a pre-built version of a compatible
wheel available and short-circuits the standard full build pipeline.
"""

def run(self):
if FORCE_BUILD:
return super().run()

wheel_url, wheel_filename = get_wheel_url()
try:
urllib.request.urlretrieve(wheel_url, wheel_filename)

# Make the archive
# Lifted from the root wheel processing command
# https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
if not os.path.exists(self.dist_dir):
os.makedirs(self.dist_dir)

impl_tag, abi_tag, plat_tag = self.get_tag()
archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
print("Raw wheel path", wheel_path)
shutil.move(wheel_filename, wheel_path)
except (urllib.error.HTTPError, urllib.error.URLError):
print("Precompiled wheel not found. Building from source...")
# If the wheel could not be downloaded, build from source
super().run()


def main():
# Submodules to install
flux_version = get_package_version()
packages = setuptools.find_packages(
where="python",
include=["flux", "flux.pynvshmem", "flux_ths_pybind"],
Expand All @@ -140,18 +249,19 @@ def main():
]
# Configure package
setuptools.setup(
name="flux",
name=PACKAGE_NAME,
version=flux_version,
package_dir={"": "python"},
packages=packages,
description="Flux library",
ext_modules=[setup_pytorch_extension()],
cmdclass={"build_ext": BuildExtension},
setup_requires=["torch", "cmake"],
cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension},
setup_requires=["torch", "cmake", "packaging"],
install_requires=["torch"],
extras_require={"test": ["torch", "numpy"]},
license_files=("LICENSE",),
package_data={"python/lib": ["*.so"]}, # only works for sdist
python_requires=">=3.8",
# include_package_data=True,
data_files=[
(
Expand Down
1 change: 0 additions & 1 deletion version.txt

This file was deleted.

0 comments on commit a65ee72

Please sign in to comment.