diff --git a/Dockerfile b/Dockerfile
index de030f85..8c2efa85 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,5 @@
# syntax=docker/dockerfile:1.7-labs
-FROM nvcr.io/nvidia/pytorch:24.07-py3
+FROM nvcr.io/nvidia/pytorch:24.11-py3
# Install dependencies.
RUN apt-get update \
@@ -20,9 +20,9 @@ RUN mkdir -m 777 /app/Megatron-LM /app/examples /app/fast_llm /app/tests /app/to
/usr/local \
/usr/local/bin \
/usr/local/lib \
- /usr/local/lib/python3.10 \
- /usr/local/lib/python3.10/dist-packages \
- /usr/local/lib/python3.10/dist-packages/__pycache__
+ /usr/local/lib/python3.12 \
+ /usr/local/lib/python3.12/dist-packages \
+ /usr/local/lib/python3.12/dist-packages/__pycache__
# Copy dependency files with universal write permissions for all users.
COPY --chmod=777 setup.py setup.cfg pyproject.toml ./
diff --git a/README.md b/README.md
index 9da114bb..d02e7f95 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
-

+

[![Docker][ci-badge]][ci-workflow]
[![Documentation][docs-badge]][docs-workflow]
diff --git a/docs/assets/images/logo.png b/docs/assets/images/logo.png
deleted file mode 100644
index 6141c4dd..00000000
Binary files a/docs/assets/images/logo.png and /dev/null differ
diff --git a/docs/assets/images/logo.svg b/docs/assets/images/logo.svg
new file mode 100644
index 00000000..bd534f63
--- /dev/null
+++ b/docs/assets/images/logo.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/fast_llm/__init__.py b/fast_llm/__init__.py
index 3dc1f76b..d3ec452c 100644
--- a/fast_llm/__init__.py
+++ b/fast_llm/__init__.py
@@ -1 +1 @@
-__version__ = "0.1.0"
+__version__ = "0.2.0"
diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py
index db8188d7..ac01d362 100644
--- a/fast_llm/functional/triton/mlp.py
+++ b/fast_llm/functional/triton/mlp.py
@@ -25,6 +25,9 @@
from fast_llm.tensor import param_get_and_unset_is_zero
from triton import language as tl
+# Triton requires global variables referenced inside jitted kernels to be annotated with `tl.constexpr`.
+_TritonActivationType: tl.constexpr = ActivationType
+
@triton.jit
def triton_mlp_activation_forward_kernel(
@@ -47,15 +50,15 @@ def triton_mlp_activation_forward_kernel(
input_ = tl.load(input_ptr, mask=mask).to(tl.float32)
- if activation_type == ActivationType.gelu:
+ if activation_type == _TritonActivationType.gelu.value:
tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_)
tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input))
out = input_ * 0.5 * (1.0 + tanh)
- elif activation_type == ActivationType.silu:
+ elif activation_type == _TritonActivationType.silu.value:
out = input_ / (1 + tl.exp(-input_))
- elif activation_type == ActivationType.relu:
+ elif activation_type == _TritonActivationType.relu.value:
out = tl.where(input_ > 0, input_, 0)
- elif activation_type == ActivationType.squared_relu:
+    elif activation_type == _TritonActivationType.squared_relu.value:
relu_out = tl.where(input_ > 0, input_, 0)
out = relu_out * relu_out
else:
@@ -95,23 +98,23 @@ def triton_mlp_activation_backward_kernel(
input_ = tl.load(input_ptr, mask=mask).to(tl.float32)
output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32)
- if activation_type == ActivationType.gelu:
+ if activation_type == _TritonActivationType.gelu:
tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_)
tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input))
grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh)
if gated or recompute:
out = input_ * 0.5 * (1.0 + tanh)
- elif activation_type == ActivationType.silu:
+ elif activation_type == _TritonActivationType.silu:
exp = tl.exp(-input_)
sigma = 1 / (1 + exp)
grad = sigma * sigma + (1 + input_) / (2 + exp + 1 / exp)
if gated or recompute:
out = input_ * sigma
- elif activation_type == ActivationType.relu:
+ elif activation_type == _TritonActivationType.relu:
grad = tl.where(input_ > 0, 1, 0)
if gated or recompute:
out = tl.where(input_ > 0, input_, 0)
- elif activation_type == ActivationType.squared_relu:
+ elif activation_type == _TritonActivationType.squared_relu:
relu_out = tl.where(input_ > 0, input_, 0)
grad = 2 * relu_out
if gated or recompute:
@@ -148,7 +151,7 @@ def triton_mlp_activation_forward(
input_,
output,
gated=gated, # noqa
- activation_type=activation_type, # noqa
+ activation_type=activation_type.value, # noqa
n_cols=n_cols, # noqa
block_size=TritonConfig.POINTWISE_BLOCK_SIZE,
)
diff --git a/mkdocs.yaml b/mkdocs.yaml
index 4a137fcf..eaec87d4 100644
--- a/mkdocs.yaml
+++ b/mkdocs.yaml
@@ -18,8 +18,8 @@ copyright: Copyright 2024 ServiceNow, Inc.
theme:
name: material
custom_dir: docs/overrides
- logo: assets/images/logo.png
- favicon: assets/images/logo.png
+ logo: assets/images/logo.svg
+ favicon: assets/images/logo.svg
icon:
repo: fontawesome/brands/github
features:
@@ -58,15 +58,15 @@ theme:
name: Switch to light mode
- media: "(prefers-color-scheme: light)"
scheme: default
- primary: indigo
- accent: indigo
+ primary: white
+ accent: white
toggle:
icon: material/toggle-switch
name: Switch to dark mode
- media: "(prefers-color-scheme: dark)"
scheme: slate
primary: black
- accent: indigo
+ accent: white
toggle:
icon: material/toggle-switch-off
name: Switch to system preference
diff --git a/setup.cfg b/setup.cfg
index 5429dc91..95ec3b69 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,7 +1,7 @@
[metadata]
name = fast_llm
# TODO: Take from __init__.py instead?
-version = 0.1.0
+version = 0.2.0
[options]
packages = find_namespace:
@@ -25,7 +25,7 @@ CORE =
# Used for checkpoints
safetensors>=0.4.4
# Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation
- flash-attn==2.6.3
+ flash-attn==2.7.2.post1
# Required for some optional features and tools.
OPTIONAL =
@@ -45,6 +45,8 @@ OPTIONAL =
DEV =
pytest>=8.3.2
pytest-depends>=1.0.1
+ # Somehow needed for Megatron to work with base image 24.11
+ setuptools>=75.6.0
# Required for building the documentation
DOCS =