feat(server): flash neoX (#133)
OlivierDehaene authored Mar 24, 2023
1 parent 23e1028 commit 05e9a79
Showing 10 changed files with 1,307 additions and 25 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/build.yaml
@@ -20,6 +20,10 @@ on:
branches:
- 'main'

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true

jobs:
build-and-push-image:
runs-on: ubuntu-latest
4 changes: 4 additions & 0 deletions .github/workflows/tests.yaml
@@ -11,6 +11,10 @@ on:
- "Cargo.lock"
- "rust-toolchain.toml"

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true

jobs:
run_tests:
runs-on: ubuntu-20.04
9 changes: 6 additions & 3 deletions Dockerfile
@@ -43,7 +43,7 @@ ENV LANG=C.UTF-8 \
CONDA_DEFAULT_ENV=text-generation \
PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin

RUN apt-get update && apt-get install -y unzip curl libssl-dev && rm -rf /var/lib/apt/lists/*
RUN apt-get update && apt-get install -y git curl libssl-dev && rm -rf /var/lib/apt/lists/*

RUN cd ~ && \
curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
@@ -53,10 +53,13 @@ RUN cd ~ && \

WORKDIR /usr/src

# Install torch
RUN pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir

COPY server/Makefile server/Makefile

# Install specific version of torch
RUN cd server && make install-torch
# Install specific version of flash attention
RUN cd server && make install-flash-attention

# Install specific version of transformers
RUN cd server && BUILD_EXTENSIONS="True" make install-transformers
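
Since the image now installs the cu118 torch wheel with pip before building the flash attention kernels, a quick runtime check inside the built image can confirm the expected CUDA build. This is a sketch, not part of the commit; the exact version strings are assumptions.

# Sanity check for the image's torch install (sketch, not part of this commit).
# The Dockerfile pulls torch from the cu118 extra index, so the build should
# report CUDA 11.8 and see a GPU for the flash attention path to be usable.
import torch

print(torch.__version__)          # expected to be a cu118 build
print(torch.version.cuda)         # expected: "11.8"
print(torch.cuda.is_available())  # must be True for FLASH_NEOX to activate
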
17 changes: 12 additions & 5 deletions server/Makefile
@@ -1,4 +1,5 @@
transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef
flash_att_commit := 4d87e4d875077ad9efd25030efa4ab0ba92c19e1

gen-server:
# Compile protos
@@ -12,13 +13,19 @@ install-transformers:
# Install specific version of transformers with custom cuda kernels
pip uninstall transformers -y || true
rm -rf transformers || true
rm -rf transformers-$(transformers_commit) || true
curl -L -O https://github.com/OlivierDehaene/transformers/archive/$(transformers_commit).zip
unzip $(transformers_commit).zip
rm $(transformers_commit).zip
mv transformers-$(transformers_commit) transformers
git clone https://github.com/OlivierDehaene/transformers.git
cd transformers && git checkout $(transformers_commit)
cd transformers && python setup.py install

install-flash-attention:
# Install specific version of flash attention
pip install packaging
pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true
rm -rf flash-attention || true
git clone https://github.com/HazyResearch/flash-attention.git
cd flash-attention && git checkout $(flash_att_commit)
cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install

install-torch:
# Install specific version of torch
pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
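
The new install-flash-attention target builds three CUDA extensions from the pinned HazyResearch/flash-attention checkout: the core attention kernels plus the fused layer norm and rotary embedding ops listed in the pip uninstall line above. A quick post-install check follows, as a sketch; the importable module names are assumed to match those pip names.

# Post-install check for `make install-flash-attention` (sketch, not part of
# this commit). Module names are assumed from the pip uninstall line above.
import flash_attn          # core flash attention kernels
import dropout_layer_norm  # fused layer norm extension built from csrc/layer_norm
import rotary_emb          # rotary embedding extension built from csrc/rotary

print(getattr(flash_attn, "__version__", "unknown"))
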
20 changes: 18 additions & 2 deletions server/text_generation_server/models/__init__.py
@@ -1,5 +1,7 @@
import os
import torch

from loguru import logger
from transformers import AutoConfig
from typing import Optional

@@ -12,6 +14,14 @@
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded

try:
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
FLASH_NEOX = torch.cuda.is_available() and int(os.environ.get("FLASH_NEOX", 0)) == 1
except ImportError:
if int(os.environ.get("FLASH_NEOX", 0)) == 1:
logger.exception("Could not import FlashNeoX")
FLASH_NEOX = False

__all__ = [
"Model",
"BLOOM",
@@ -26,6 +36,10 @@
"get_model",
]

if FLASH_NEOX:
__all__.append(FlashNeoX)
__all__.append(FlashNeoXSharded)

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
@@ -59,9 +73,11 @@ def get_model(

if config.model_type == "gpt_neox":
if sharded:
return GPTNeoxSharded(model_id, revision, quantize=quantize)
neox_cls = FlashNeoXSharded if FLASH_NEOX else GPTNeoxSharded
return neox_cls(model_id, revision, quantize=quantize)
else:
return CausalLM(model_id, revision, quantize=quantize)
neox_cls = FlashNeoX if FLASH_NEOX else CausalLM
return neox_cls(model_id, revision, quantize=quantize)

if config.model_type == "t5":
if sharded:
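
Taken together, the changes above make the flash attention path strictly opt-in: FlashNeoX is only selected when the flash_neox import succeeds, CUDA is available, and FLASH_NEOX=1 is set before the models package is imported. A minimal usage sketch under those assumptions follows; the get_model argument names are inferred from this hunk and may not match the full signature.

# Opt-in usage sketch (not part of this commit). FLASH_NEOX is read at import
# time in text_generation_server/models/__init__.py, so it must be set first.
import os

os.environ["FLASH_NEOX"] = "1"

from text_generation_server.models import get_model

# With sharded=False this dispatches to FlashNeoX when the flash kernels are
# available and falls back to CausalLM otherwise; argument names are
# assumptions based on the diff above.
model = get_model("EleutherAI/gpt-neox-20b", revision=None, sharded=False, quantize=False)
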
1 change: 0 additions & 1 deletion server/text_generation_server/models/causal_lm.py
@@ -64,7 +64,6 @@ def from_pb(
inputs = []
next_token_choosers = []
stopping_criterias = []
input_lengths = []

# Parse batch
padding_right_offset = 0