diff --git a/Dockerfile b/Dockerfile
index de030f85..8c2efa85 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,5 @@
 # syntax=docker/dockerfile:1.7-labs
-FROM nvcr.io/nvidia/pytorch:24.07-py3
+FROM nvcr.io/nvidia/pytorch:24.11-py3
 
 # Install dependencies.
 RUN apt-get update \
@@ -20,9 +20,9 @@ RUN mkdir -m 777 /app/Megatron-LM /app/examples /app/fast_llm /app/tests /app/to
     /usr/local \
     /usr/local/bin \
     /usr/local/lib \
-    /usr/local/lib/python3.10 \
-    /usr/local/lib/python3.10/dist-packages \
-    /usr/local/lib/python3.10/dist-packages/__pycache__
+    /usr/local/lib/python3.12 \
+    /usr/local/lib/python3.12/dist-packages \
+    /usr/local/lib/python3.12/dist-packages/__pycache__
 
 # Copy dependency files with universal write permissions for all users.
 COPY --chmod=777 setup.py setup.cfg pyproject.toml ./
diff --git a/README.md b/README.md
index 9da114bb..d02e7f95 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
-Fast-LLM Logo
+Fast-LLM
 
 [![Docker][ci-badge]][ci-workflow]
 [![Documentation][docs-badge]][docs-workflow]
diff --git a/docs/assets/images/logo.png b/docs/assets/images/logo.png
deleted file mode 100644
index 6141c4dd..00000000
Binary files a/docs/assets/images/logo.png and /dev/null differ
diff --git a/docs/assets/images/logo.svg b/docs/assets/images/logo.svg
new file mode 100644
index 00000000..bd534f63
--- /dev/null
+++ b/docs/assets/images/logo.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/fast_llm/__init__.py b/fast_llm/__init__.py
index 3dc1f76b..d3ec452c 100644
--- a/fast_llm/__init__.py
+++ b/fast_llm/__init__.py
@@ -1 +1 @@
-__version__ = "0.1.0"
+__version__ = "0.2.0"
diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py
index db8188d7..ac01d362 100644
--- a/fast_llm/functional/triton/mlp.py
+++ b/fast_llm/functional/triton/mlp.py
@@ -25,6 +25,9 @@
 from fast_llm.tensor import param_get_and_unset_is_zero
 from triton import language as tl
 
+# Triton requires global variables to be annotated with `tl.constexpr`.
+_TritonActivationType: tl.constexpr = ActivationType
+
 
 @triton.jit
 def triton_mlp_activation_forward_kernel(
@@ -47,15 +50,15 @@ def triton_mlp_activation_forward_kernel(
     input_ = tl.load(input_ptr, mask=mask).to(tl.float32)
 
-    if activation_type == ActivationType.gelu:
+    if activation_type == _TritonActivationType.gelu.value:
         tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_)
         tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input))
         out = input_ * 0.5 * (1.0 + tanh)
-    elif activation_type == ActivationType.silu:
+    elif activation_type == _TritonActivationType.silu.value:
         out = input_ / (1 + tl.exp(-input_))
-    elif activation_type == ActivationType.relu:
+    elif activation_type == _TritonActivationType.relu.value:
         out = tl.where(input_ > 0, input_, 0)
-    elif activation_type == ActivationType.squared_relu:
+    elif activation_type == _TritonActivationType.squared_relu:
         relu_out = tl.where(input_ > 0, input_, 0)
         out = relu_out * relu_out
     else:
@@ -95,23 +98,23 @@ def triton_mlp_activation_backward_kernel(
     input_ = tl.load(input_ptr, mask=mask).to(tl.float32)
     output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32)
 
-    if activation_type == ActivationType.gelu:
+    if activation_type == _TritonActivationType.gelu:
         tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_)
         tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input))
         grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh)
         if gated or recompute:
             out = input_ * 0.5 * (1.0 + tanh)
-    elif activation_type == ActivationType.silu:
+    elif activation_type == _TritonActivationType.silu:
         exp = tl.exp(-input_)
         sigma = 1 / (1 + exp)
         grad = sigma * sigma + (1 + input_) / (2 + exp + 1 / exp)
         if gated or recompute:
             out = input_ * sigma
-    elif activation_type == ActivationType.relu:
+    elif activation_type == _TritonActivationType.relu:
         grad = tl.where(input_ > 0, 1, 0)
         if gated or recompute:
             out = tl.where(input_ > 0, input_, 0)
-    elif activation_type == ActivationType.squared_relu:
+    elif activation_type == _TritonActivationType.squared_relu:
         relu_out = tl.where(input_ > 0, input_, 0)
         grad = 2 * relu_out
         if gated or recompute:
@@ -148,7 +151,7 @@ def triton_mlp_activation_forward(
         input_,
         output,
         gated=gated,  # noqa
-        activation_type=activation_type,  # noqa
+        activation_type=activation_type.value,  # noqa
         n_cols=n_cols,  # noqa
         block_size=TritonConfig.POINTWISE_BLOCK_SIZE,
     )
diff --git a/mkdocs.yaml b/mkdocs.yaml
index 4a137fcf..eaec87d4 100644
--- a/mkdocs.yaml
+++ b/mkdocs.yaml
@@ -18,8 +18,8 @@ copyright: Copyright 2024 ServiceNow, Inc.
 theme:
   name: material
   custom_dir: docs/overrides
-  logo: assets/images/logo.png
-  favicon: assets/images/logo.png
+  logo: assets/images/logo.svg
+  favicon: assets/images/logo.svg
   icon:
     repo: fontawesome/brands/github
   features:
@@ -58,15 +58,15 @@ theme:
         name: Switch to light mode
     - media: "(prefers-color-scheme: light)"
      scheme: default
-      primary: indigo
-      accent: indigo
+      primary: white
+      accent: white
       toggle:
        icon: material/toggle-switch
        name: Switch to dark mode
     - media: "(prefers-color-scheme: dark)"
      scheme: slate
      primary: black
-      accent: indigo
+      accent: white
       toggle:
        icon: material/toggle-switch-off
        name: Switch to system preference
diff --git a/setup.cfg b/setup.cfg
index 5429dc91..95ec3b69 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,7 +1,7 @@
 [metadata]
 name = fast_llm
 # TODO: Take from __init__.py instead?
-version = 0.1.0
+version = 0.2.0
 
 [options]
 packages = find_namespace:
@@ -25,7 +25,7 @@ CORE =
     # Used for checkpoints
     safetensors>=0.4.4
     # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation
-    flash-attn==2.6.3
+    flash-attn==2.7.2.post1
 
 # Required for some optional features and tools.
 OPTIONAL =
@@ -45,6 +45,8 @@ OPTIONAL =
 DEV =
     pytest>=8.3.2
     pytest-depends>=1.0.1
+    # Somehow needed for Megatron to work with base image 24.11
+    setuptools>=75.6.0
 
 # Required for building the documentation
 DOCS =
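
The `fast_llm/functional/triton/mlp.py` hunk above reflects the rule noted in its own comment: the Triton bundled with the newer base image requires any module-level global referenced inside a `@triton.jit` function to be annotated with `tl.constexpr`, so the activation enum is exposed through the `_TritonActivationType` alias and the launcher passes `activation_type.value` rather than the enum member. Below is a minimal, self-contained sketch of that pattern under stated assumptions; it is not Fast-LLM's actual code: `ActivationType` is reduced to two members, and `_activation_kernel` and `apply_activation` are illustrative names.

```python
# Sketch (assumed names, not Fast-LLM's API) of exposing a str-valued enum to a
# @triton.jit kernel via a tl.constexpr-annotated module-level alias, with the
# launcher passing the enum's .value so the kernel specializes on a plain string.
import enum

import torch
import triton
from triton import language as tl


class ActivationType(str, enum.Enum):  # simplified stand-in for Fast-LLM's ActivationType
    relu = "relu"
    silu = "silu"


# Triton requires globals referenced inside a jitted function to be annotated with tl.constexpr.
_TritonActivationType: tl.constexpr = ActivationType


@triton.jit
def _activation_kernel(input_ptr, output_ptr, n_cols, activation_type: tl.constexpr, block_size: tl.constexpr):
    offsets = tl.program_id(0) * block_size + tl.arange(0, block_size)
    mask = offsets < n_cols
    input_ = tl.load(input_ptr + offsets, mask=mask).to(tl.float32)
    # Compile-time branch: activation_type is a constexpr string, so only the
    # selected activation is compiled into this kernel specialization.
    if activation_type == _TritonActivationType.relu.value:
        out = tl.where(input_ > 0, input_, 0)
    else:  # silu
        out = input_ / (1 + tl.exp(-input_))
    tl.store(output_ptr + offsets, out, mask=mask)


def apply_activation(input_: torch.Tensor, activation_type: ActivationType, block_size: int = 1024) -> torch.Tensor:
    output = torch.empty_like(input_)
    n_cols = input_.numel()
    grid = (triton.cdiv(n_cols, block_size),)
    # Pass the string value, not the enum member, so the kernel argument is a plain constexpr.
    _activation_kernel[grid](input_, output, n_cols, activation_type=activation_type.value, block_size=block_size)
    return output
```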