v0.2 (ServiceNow#90)
Co-authored-by: Torsten Scholak <torsten.scholak@googlemail.com>
jlamypoirier and tscholak authored Dec 17, 2024
1 parent d8f3390 commit f4a0cdd
Showing 8 changed files with 28 additions and 22 deletions.
8 changes: 4 additions & 4 deletions Dockerfile
@@ -1,5 +1,5 @@
# syntax=docker/dockerfile:1.7-labs
-FROM nvcr.io/nvidia/pytorch:24.07-py3
+FROM nvcr.io/nvidia/pytorch:24.11-py3

# Install dependencies.
RUN apt-get update \
@@ -20,9 +20,9 @@ RUN mkdir -m 777 /app/Megatron-LM /app/examples /app/fast_llm /app/tests /app/to
/usr/local \
/usr/local/bin \
/usr/local/lib \
-/usr/local/lib/python3.10 \
-/usr/local/lib/python3.10/dist-packages \
-/usr/local/lib/python3.10/dist-packages/__pycache__
+/usr/local/lib/python3.12 \
+/usr/local/lib/python3.12/dist-packages \
+/usr/local/lib/python3.12/dist-packages/__pycache__

# Copy dependency files with universal write permissions for all users.
COPY --chmod=777 setup.py setup.cfg pyproject.toml ./
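
The dist-packages paths above are hardcoded per Python minor version, so the bump from the 24.07 base image (Python 3.10) to 24.11 (Python 3.12) also moves them. A small illustrative check, not part of this commit, that the hardcoded prefix matches the interpreter actually shipped in the image (the path literal is an assumption taken from the Dockerfile above):

import site
import sys

# Illustrative check: compare the Dockerfile's hardcoded prefix with the
# running interpreter's dist-packages location inside the base image.
hardcoded = "/usr/local/lib/python3.12/dist-packages"
actual = f"/usr/local/lib/python{sys.version_info.major}.{sys.version_info.minor}/dist-packages"

print("hardcoded:", hardcoded)
print("interpreter:", actual)
print("match:", hardcoded == actual)
print("known to site:", actual in site.getsitepackages())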
2 changes: 1 addition & 1 deletion README.md
@@ -1,6 +1,6 @@
<div align="center" style="margin-bottom: 1em;">

-<img width=50% src="docs/assets/images/logo.png" alt="Fast-LLM Logo"></img>
+<img width=50% src="docs/assets/images/logo.svg" alt="Fast-LLM"></img>

[![Docker][ci-badge]][ci-workflow]
[![Documentation][docs-badge]][docs-workflow]
Binary file removed docs/assets/images/logo.png
1 change: 1 addition & 0 deletions docs/assets/images/logo.svg
2 changes: 1 addition & 1 deletion fast_llm/__init__.py
@@ -1 +1 @@
__version__ = "0.1.0"
__version__ = "0.2.0"
21 changes: 12 additions & 9 deletions fast_llm/functional/triton/mlp.py
@@ -25,6 +25,9 @@
from fast_llm.tensor import param_get_and_unset_is_zero
from triton import language as tl

+# Triton requires global variables to be annotated with `tl.constexpr`.
+_TritonActivationType: tl.constexpr = ActivationType


@triton.jit
def triton_mlp_activation_forward_kernel(
@@ -47,15 +47,15 @@ def triton_mlp_activation_forward_kernel(

input_ = tl.load(input_ptr, mask=mask).to(tl.float32)

-if activation_type == ActivationType.gelu:
+if activation_type == _TritonActivationType.gelu.value:
tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_)
tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input))
out = input_ * 0.5 * (1.0 + tanh)
-elif activation_type == ActivationType.silu:
+elif activation_type == _TritonActivationType.silu.value:
out = input_ / (1 + tl.exp(-input_))
-elif activation_type == ActivationType.relu:
+elif activation_type == _TritonActivationType.relu.value:
out = tl.where(input_ > 0, input_, 0)
-elif activation_type == ActivationType.squared_relu:
+elif activation_type == _TritonActivationType.squared_relu:
relu_out = tl.where(input_ > 0, input_, 0)
out = relu_out * relu_out
else:
@@ -95,23 +98,23 @@ def triton_mlp_activation_backward_kernel(
input_ = tl.load(input_ptr, mask=mask).to(tl.float32)
output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32)

-if activation_type == ActivationType.gelu:
+if activation_type == _TritonActivationType.gelu:
tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_)
tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input))
grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh)
if gated or recompute:
out = input_ * 0.5 * (1.0 + tanh)
-elif activation_type == ActivationType.silu:
+elif activation_type == _TritonActivationType.silu:
exp = tl.exp(-input_)
sigma = 1 / (1 + exp)
grad = sigma * sigma + (1 + input_) / (2 + exp + 1 / exp)
if gated or recompute:
out = input_ * sigma
-elif activation_type == ActivationType.relu:
+elif activation_type == _TritonActivationType.relu:
grad = tl.where(input_ > 0, 1, 0)
if gated or recompute:
out = tl.where(input_ > 0, input_, 0)
-elif activation_type == ActivationType.squared_relu:
+elif activation_type == _TritonActivationType.squared_relu:
relu_out = tl.where(input_ > 0, input_, 0)
grad = 2 * relu_out
if gated or recompute:
@@ -148,7 +151,7 @@ def triton_mlp_activation_forward(
input_,
output,
gated=gated, # noqa
-activation_type=activation_type, # noqa
+activation_type=activation_type.value, # noqa
n_cols=n_cols, # noqa
block_size=TritonConfig.POINTWISE_BLOCK_SIZE,
)
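
The changes above work around a Triton restriction called out in the new comment: globals referenced inside a `@triton.jit` kernel must be annotated with `tl.constexpr`, so the enum class is re-exported as `_TritonActivationType` and the host launch passes the plain integer `activation_type.value`. A minimal, self-contained sketch of the same pattern (illustrative only; the enum members, kernel, and wrapper here are assumptions, not Fast-LLM code):

import enum

import torch
import triton
from triton import language as tl


class ActivationType(enum.IntEnum):
    relu = 0
    squared_relu = 1


# Triton resolves module-level globals inside a kernel only when they are
# annotated with `tl.constexpr`, so bind the enum class under that annotation.
_ActivationType: tl.constexpr = ActivationType


@triton.jit
def _activation_kernel(input_ptr, output_ptr, n_elements, activation_type: tl.constexpr, block_size: tl.constexpr):
    offsets = tl.program_id(0) * block_size + tl.arange(0, block_size)
    mask = offsets < n_elements
    input_ = tl.load(input_ptr + offsets, mask=mask).to(tl.float32)
    # `activation_type` is a constexpr integer, so this branch is specialized at compile time.
    if activation_type == _ActivationType.relu.value:
        out = tl.where(input_ > 0, input_, 0.0)
    else:
        relu_out = tl.where(input_ > 0, input_, 0.0)
        out = relu_out * relu_out
    tl.store(output_ptr + offsets, out, mask=mask)


def activation_forward(input_: torch.Tensor, activation_type: ActivationType) -> torch.Tensor:
    output = torch.empty_like(input_)
    n_elements = input_.numel()
    block_size = 1024
    grid = (triton.cdiv(n_elements, block_size),)
    # Pass the plain integer value, as in the launch above.
    _activation_kernel[grid](input_, output, n_elements, activation_type.value, block_size)
    return output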
10 changes: 5 additions & 5 deletions mkdocs.yaml
@@ -18,8 +18,8 @@ copyright: Copyright 2024 ServiceNow, Inc.
theme:
name: material
custom_dir: docs/overrides
-logo: assets/images/logo.png
-favicon: assets/images/logo.png
+logo: assets/images/logo.svg
+favicon: assets/images/logo.svg
icon:
repo: fontawesome/brands/github
features:
@@ -58,15 +58,15 @@ theme:
name: Switch to light mode
- media: "(prefers-color-scheme: light)"
scheme: default
-primary: indigo
-accent: indigo
+primary: white
+accent: white
toggle:
icon: material/toggle-switch
name: Switch to dark mode
- media: "(prefers-color-scheme: dark)"
scheme: slate
primary: black
-accent: indigo
+accent: white
toggle:
icon: material/toggle-switch-off
name: Switch to system preference
6 changes: 4 additions & 2 deletions setup.cfg
@@ -1,7 +1,7 @@
[metadata]
name = fast_llm
# TODO: Take from __init__.py instead?
-version = 0.1.0
+version = 0.2.0

[options]
packages = find_namespace:
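
A possible answer to the TODO above (not part of this commit): setuptools can read the version from the package itself with `version = attr: fast_llm.__version__` under `[metadata]`, which would avoid bumping it in both `setup.cfg` and `fast_llm/__init__.py`.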
@@ -25,7 +25,7 @@ CORE =
# Used for checkpoints
safetensors>=0.4.4
# Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation
-flash-attn==2.6.3
+flash-attn==2.7.2.post1

# Required for some optional features and tools.
OPTIONAL =
@@ -45,6 +45,8 @@ OPTIONAL =
DEV =
pytest>=8.3.2
pytest-depends>=1.0.1
+# Somehow needed for Megatron to work with base image 24.11
+setuptools>=75.6.0

# Required for building the documentation
DOCS =