This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Commit

fix num_generated_tokens and drop mii (#44)
mayank31398 authored Jan 23, 2023
1 parent e970be1 commit 9d48dbf
Showing 14 changed files with 140 additions and 503 deletions.
70 changes: 37 additions & 33 deletions Dockerfile
@@ -1,10 +1,6 @@
FROM nvidia/cuda:11.6.0-devel-ubi8 as cuda
FROM nvidia/cuda:11.6.1-devel-ubi8 as base

ENV PORT=5000

WORKDIR /src

FROM cuda as conda
RUN dnf install -y --disableplugin=subscription-manager make git && dnf clean all --disableplugin=subscription-manager

# taken from pytorch's dockerfile
RUN curl -L -o ./miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
@@ -21,47 +17,55 @@ RUN conda create -n inference python=${PYTHON_VERSION} pip -y
# change shell to activate env
SHELL ["conda", "run", "-n", "inference", "/bin/bash", "-c"]

FROM conda as conda_env
FROM base as conda

# update conda
RUN conda update -n base -c defaults conda -y
# cmake
RUN conda install -c anaconda cmake -y

# update conda
RUN conda update -n base -c defaults conda -y

# necessary stuff
RUN pip install torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 \
transformers \
deepspeed==0.7.5 \
deepspeed-mii==0.0.2 \
accelerate \
gunicorn \
transformers==4.25.1 \
deepspeed==0.7.6 \
accelerate==0.15.0 \
gunicorn==20.1.0 \
flask \
flask_api \
pydantic \
huggingface_hub \
flask_api \
fastapi==0.89.1 \
uvicorn==0.19.0 \
jinja2==3.1.2 \
pydantic==1.10.2 \
huggingface_hub==0.10.1 \
grpcio-tools==1.50.0 \
--no-cache-dir

# copy the code
COPY inference_server inference_server
COPY Makefile Makefile
COPY LICENSE LICENSE

# install grpc and compile protos
RUN make gen-proto

# clean conda env
RUN conda clean -ya

EXPOSE ${PORT}

# change this as you like 🤗
ENV TRANSFORMERS_CACHE=/transformers_cache/ \
HUGGINGFACE_HUB_CACHE=${TRANSFORMERS_CACHE} \
HOME=/homedir
ENV TRANSFORMERS_CACHE=/cos/HF_cache \
HUGGINGFACE_HUB_CACHE=${TRANSFORMERS_CACHE}

RUN mkdir ${HOME} && chmod g+wx ${HOME} && \
mkdir tmp && chmod -R g+w tmp
FROM conda as app

# for debugging
# RUN chmod -R g+w inference_server && chmod g+w Makefile
WORKDIR /src
RUN chmod -R g+w /src

RUN mkdir /.cache && \
chmod -R g+w /.cache

CMD make bloom-176b
ENV PORT=5000 \
UI_PORT=5001
EXPOSE ${PORT}
EXPOSE ${UI_PORT}

CMD git clone https://github.com/huggingface/transformers-bloom-inference.git && \
cd transformers-bloom-inference && \
# install grpc and compile protos
make gen-proto && \
make ui && \
make bloom-560m
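
The new CMD clones the repository at container start, generates the gRPC stubs, starts the UI on UI_PORT (5001), and serves bloom-560m on PORT (5000). A minimal client sketch for the running container follows; the /generate/ route and the text / max_new_tokens / num_generated_tokens fields are assumptions drawn from the rest of this repository rather than something this diff guarantees:

    # Minimal client sketch; route name and request/response fields are assumptions.
    import requests

    response = requests.post(
        "http://localhost:5000/generate/",
        json={"text": ["DeepSpeed is a machine learning framework"], "max_new_tokens": 40},
    )
    result = response.json()
    print(result["text"][0])
    print(result["num_generated_tokens"][0])  # the field this commit's title refers to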
5 changes: 3 additions & 2 deletions Makefile
@@ -1,6 +1,4 @@
gen-proto:
pip install grpcio-tools==1.50.0 --no-cache-dir

mkdir -p inference_server/model_handler/grpc_utils/pb

python -m grpc_tools.protoc -Iinference_server/model_handler/grpc_utils/proto --python_out=inference_server/model_handler/grpc_utils/pb --grpc_python_out=inference_server/model_handler/grpc_utils/pb inference_server/model_handler/grpc_utils/proto/generation.proto
@@ -100,3 +98,6 @@ codegen-mono:
MAX_BATCH_SIZE=4 \
CUDA_VISIBLE_DEVICES=0 \
gunicorn -t 0 -w 1 -b 127.0.0.1:5000 inference_server.server:app --access-logfile - --access-logformat '%(h)s %(t)s "%(r)s" %(s)s %(b)s'

ui:
python -m ui &
4 changes: 0 additions & 4 deletions inference_server/constants.py
@@ -3,8 +3,4 @@
DS_INFERENCE = "ds_inference"
DS_ZERO = "ds_zero"

# model weights
DS_INFERENCE_BLOOM_FP16 = "microsoft/bloom-deepspeed-inference-fp16"
DS_INFERENCE_BLOOM_INT8 = "microsoft/bloom-deepspeed-inference-int8"

# GRPC_MAX_MSG_SIZE = 2**30 # 1GB
14 changes: 12 additions & 2 deletions inference_server/download_model.py
@@ -1,6 +1,7 @@
import argparse

from .models import get_downloaded_model_path
from inference_server.models import get_hf_model_class
from transformers import AutoConfig, AutoTokenizer


def get_args() -> argparse.Namespace:
@@ -12,6 +13,12 @@ def get_args() -> argparse.Namespace:
required=True,
help="model to use",
)
parser.add_argument(
"--model_class",
type=str,
required=True,
help="model class to use",
)

args = parser.parse_args()

@@ -20,7 +27,10 @@ def get_args() -> argparse.Namespace:

def main() -> None:
args = get_args()
get_downloaded_model_path(args.model_name)
print("downloading", args.model_name)
AutoConfig.from_pretrained(args.model_name)
AutoTokenizer.from_pretrained(args.model_name)
get_hf_model_class(args.model_class).from_pretrained(args.model_name)


if __name__ == "__main__":
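The script now pre-populates the Hugging Face cache directly instead of going through get_downloaded_model_path. get_hf_model_class itself is not shown in this diff; a minimal sketch of what it plausibly does, assuming --model_class names a transformers auto class such as AutoModelForCausalLM:

    # Hedged sketch, not necessarily the repository's actual implementation.
    import transformers

    def get_hf_model_class(model_class: str):
        # e.g. "AutoModelForCausalLM" -> transformers.AutoModelForCausalLM
        return getattr(transformers, model_class)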
29 changes: 24 additions & 5 deletions inference_server/model_handler/deployment.py
@@ -9,11 +9,9 @@
from typing import List

import grpc
from mii.server_client import MIIServerClient
from transformers import AutoTokenizer

from ..constants import DS_INFERENCE, DS_ZERO
from ..models import get_downloaded_model_path, get_model_class, load_tokenizer
from ..models import get_model_class, load_tokenizer
from ..utils import (
GenerateResponse,
TokenizeRequest,
@@ -25,14 +23,14 @@
from .grpc_utils.pb import generation_pb2, generation_pb2_grpc


class ModelDeployment(MIIServerClient):
class ModelDeployment:
def __init__(self, args: argparse.Namespace, use_grpc_server: bool = False, cuda_visible_devices: List[int] = [0]):
self.cuda_visible_devices = cuda_visible_devices
self.num_gpus = len(self.cuda_visible_devices)
self.use_grpc_server = use_grpc_server

if self.use_grpc_server:
self.tokenizer = load_tokenizer(get_downloaded_model_path(args.model_name))
self.tokenizer = load_tokenizer(args.model_name)

self.initialize_ports()

@@ -57,6 +55,27 @@ def initialize_ports(self):
for i in range(self.num_gpus):
self.ports.append(50950 + self.cuda_visible_devices[i])

def _is_socket_open(self, port):
import socket

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
result = sock.connect_ex(("0.0.0.0", port))
sock.close()
return result == 0

def _is_server_process_alive(self):
if self.process is None:
return True
try:
self.process.wait(1)
except subprocess.TimeoutExpired as err:
# timeout means we're still running and all (probably) okay
is_alive = True
else:
# no exception case
is_alive = False
return is_alive

def _wait_until_server_is_live(self):
sockets_open = False
while not sockets_open:
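These socket and process checks replace what ModelDeployment previously inherited from MII's MIIServerClient. The body of _wait_until_server_is_live is truncated above; a hedged sketch of how such a polling loop is typically written with the two helpers (the sleep interval and the error message are assumptions):

    import time

    def _wait_until_server_is_live(self):
        sockets_open = False
        while not sockets_open:
            sockets_open = self._is_socket_open(self.ports[0])
            if not self._is_server_process_alive():
                raise RuntimeError("server process died before its socket opened")
            time.sleep(4)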
2 changes: 1 addition & 1 deletion inference_server/models/__init__.py
@@ -1,5 +1,5 @@
from ..constants import DS_INFERENCE, DS_ZERO, HF_ACCELERATE
from .model import Model, get_downloaded_model_path, load_tokenizer
from .model import Model, get_hf_model_class, load_tokenizer


def get_model_class(deployment_framework: str):
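The body of get_model_class is collapsed in this view; a hedged sketch of the dispatch it most likely performs, based on the constants imported above. DSInferenceModel is an assumed class name, while HFAccelerateModel and DSZeROModel appear later in this diff:

    def get_model_class(deployment_framework: str):
        # hedged sketch of the truncated dispatch
        if deployment_framework == HF_ACCELERATE:
            from .hf_accelerate import HFAccelerateModel
            return HFAccelerateModel
        elif deployment_framework == DS_INFERENCE:
            from .ds_inference import DSInferenceModel
            return DSInferenceModel
        elif deployment_framework == DS_ZERO:
            from .ds_zero import DSZeROModel
            return DSZeROModel
        else:
            raise ValueError(f"unknown deployment framework: {deployment_framework}")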
34 changes: 24 additions & 10 deletions inference_server/models/ds_inference.py
@@ -9,10 +9,11 @@
import torch.distributed as dist

import deepspeed
from transformers import AutoConfig, AutoTokenizer
from huggingface_hub import try_to_load_from_cache
from transformers import AutoConfig

from ..utils import print_rank_n, run_rank_n
from .model import Model, get_downloaded_model_path, get_hf_model_class, load_tokenizer
from .model import Model, get_hf_model_class


# basic DeepSpeed inference model class for benchmarking
@@ -24,26 +25,23 @@ def __init__(self, args: Namespace) -> None:

world_size = int(os.getenv("WORLD_SIZE", "1"))

downloaded_model_path = get_downloaded_model_path(args.model_name)

self.tokenizer = load_tokenizer(downloaded_model_path)
self.pad = self.tokenizer.pad_token_id

# create dummy tensors for allocating space which will be filled with
# the actual weights while calling deepspeed.init_inference in the
# following code
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
self.model = get_hf_model_class(args.model_class).from_config(
AutoConfig.from_pretrained(downloaded_model_path), torch_dtype=torch.bfloat16
AutoConfig.from_pretrained(args.model_name), torch_dtype=torch.bfloat16
)
self.model = self.model.eval()

downloaded_model_path = get_model_path(args.model_name)

if args.dtype in [torch.float16, torch.int8]:
# We currently support the weights provided by microsoft (which are
# pre-sharded)
if args.use_pre_sharded_checkpoints:
checkpoints_json = os.path.join(downloaded_model_path, "ds_inference_config.json")
checkpoints_json = os.path.join(downloaded_model_path, "ds_inference_config.json")

if os.path.isfile(checkpoints_json):
self.model = deepspeed.init_inference(
self.model,
mp_size=world_size,
@@ -60,6 +58,7 @@ def __init__(self, args: Namespace) -> None:
self.model = deepspeed.init_inference(
self.model,
mp_size=world_size,
base_dir=downloaded_model_path,
dtype=args.dtype,
checkpoint=checkpoints_json,
replace_with_kernel_inject=True,
@@ -74,6 +73,8 @@ def __init__(self, args: Namespace) -> None:
print_rank_n("Model loaded")
dist.barrier()

self.post_init(args.model_name)


class TemporaryCheckpointsJSON:
def __init__(self, model_path: str):
@@ -93,3 +94,16 @@ def __enter__(self):

def __exit__(self, type, value, traceback):
return


def get_model_path(model_name: str):
config_file = "config.json"

# will fall back to HUGGINGFACE_HUB_CACHE
config_path = try_to_load_from_cache(model_name, config_file, cache_dir=os.getenv("TRANSFORMERS_CACHE"))

if config_path is not None:
return os.path.dirname(config_path)
# treat the model name as an explicit model path
elif os.path.isfile(os.path.join(model_name, config_file)):
return model_name
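
get_model_path resolves a model name to a local directory so that DeepSpeed can read the pre-sharded checkpoint files and ds_inference_config.json directly from disk. A hedged usage sketch; the model name is one of the pre-sharded checkpoints formerly listed in constants.py and must already be present in the cache:

    # Assumes the weights were downloaded beforehand (e.g. via
    # inference_server.download_model); otherwise this returns None.
    local_dir = get_model_path("microsoft/bloom-deepspeed-inference-fp16")
    # local_dir points at the cached snapshot directory, e.g.
    # $TRANSFORMERS_CACHE/models--microsoft--bloom-deepspeed-inference-fp16/snapshots/<revision>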
17 changes: 6 additions & 11 deletions inference_server/models/ds_zero.py
@@ -5,11 +5,11 @@
import torch.distributed as dist

import deepspeed
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig
from transformers.deepspeed import HfDeepSpeedConfig

from ..utils import print_rank_n
from .model import Model, get_downloaded_model_path, get_hf_model_class, load_tokenizer
from .model import Model, get_hf_model_class


class DSZeROModel(Model):
@@ -18,9 +18,7 @@ def __init__(self, args: Namespace) -> None:

super().__init__(args)

downloaded_model_path = get_downloaded_model_path(args.model_name)

config = AutoConfig.from_pretrained(downloaded_model_path)
config = AutoConfig.from_pretrained(args.model_name)

world_size = int(os.getenv("WORLD_SIZE", "1"))
train_batch_size = 1 * world_size
@@ -54,12 +52,7 @@ def __init__(self, args: Namespace) -> None:
# this tells from_pretrained to instantiate directly on gpus
dschf = HfDeepSpeedConfig(ds_config)

self.tokenizer = load_tokenizer(downloaded_model_path)
self.pad = self.tokenizer.pad_token_id

self.model = get_hf_model_class(args.model_class).from_pretrained(
downloaded_model_path, torch_dtype=args.dtype
)
self.model = get_hf_model_class(args.model_class).from_pretrained(args.model_name, torch_dtype=args.dtype)
self.model = self.model.eval()

# convert model to a fully sharded model using ZeRO
@@ -74,3 +67,5 @@

print_rank_n("Model loaded")
dist.barrier()

self.post_init(args.model_name)
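
The ds_config dictionary handed to HfDeepSpeedConfig is collapsed in this hunk. A hedged sketch of the kind of ZeRO stage-3 inference config typically used in this spot; every field value here is illustrative, not taken from the repository:

    world_size = 1  # the real code reads WORLD_SIZE from the environment
    train_batch_size = 1 * world_size
    ds_config = {
        "fp16": {"enabled": True},
        "zero_optimization": {
            "stage": 3,
            "overlap_comm": True,
            "contiguous_gradients": True,
        },
        "steps_per_print": 2000,
        "train_micro_batch_size_per_gpu": 1,
        "train_batch_size": train_batch_size,
        "wall_clock_breakdown": False,
    }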
13 changes: 4 additions & 9 deletions inference_server/models/hf_accelerate.py
@@ -2,10 +2,8 @@

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

from ..utils import print_rank_n
from .model import Model, get_downloaded_model_path, get_hf_model_class, load_tokenizer
from .model import Model, get_hf_model_class


class HFAccelerateModel(Model):
@@ -14,12 +12,7 @@ def __init__(self, args: Namespace) -> None:

super().__init__(args)

downloaded_model_path = get_downloaded_model_path(args.model_name)

self.tokenizer = load_tokenizer(downloaded_model_path)
self.pad = self.tokenizer.pad_token_id

kwargs = {"pretrained_model_name_or_path": downloaded_model_path, "device_map": "auto"}
kwargs = {"pretrained_model_name_or_path": args.model_name, "device_map": "auto"}

if len(args.cuda_visible_devices) > 1:
kwargs["device_map"] = "balanced_low_0"
@@ -39,3 +32,5 @@
self.input_device = "cuda:0"

print_rank_n("Model loaded")

self.post_init(args.model_name)
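
Each back end now finishes with self.post_init(args.model_name), a base-class hook that is not part of this diff. Since the per-class tokenizer loading and pad-token lines were removed in each file above, a hedged sketch of what post_init most plausibly centralizes:

    # Hedged sketch of Model.post_init, inferred from the lines removed above;
    # not necessarily the repository's exact implementation.
    def post_init(self, model_name: str) -> None:
        self.tokenizer = load_tokenizer(model_name)
        self.pad = self.tokenizer.pad_token_id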

