Exchange Llama2 against Llama3 in HuggingFace_accelerate example (#3108)
* Rename llama2 to llama in HF_accelerate

* Replace llama2 with llama3 in large_models/Huggingface_accelerate

* Remove code handler

* Fix lint error
mreso authored Apr 24, 2024
1 parent 41696c0 commit 6ae146d
Showing 10 changed files with 56 additions and 206 deletions.
@@ -47,5 +47,6 @@ def hf_model(model_str):
revision=args.revision,
cache_dir=args.model_path,
use_auth_token=True,
ignore_patterns=["original/*"],
)
print(f"Files for '{args.model_name}' is downloaded to '{snapshot_path}'")
@@ -1,10 +1,10 @@
# Loading meta-llama/Llama-2-70b-chat-hf on AWS EC2 g5.24xlarge using accelerate
# Loading meta-llama/Meta-Llama-3-70B-Instruct on AWS EC2 g5.24xlarge using accelerate

This document briefs on serving large HG models with limited resource using accelerate. This option can be activated with `low_cpu_mem_usage=True`. The model is first created on the Meta device (with empty weights) and the state dict is then loaded inside it (shard by shard in the case of a sharded checkpoint).
This document describes how to serve large HF models with limited resources using accelerate. This option is activated with `low_cpu_mem_usage=True`: the model is first created on the meta device (with empty weights) and the state dict is then loaded into it (shard by shard in the case of a sharded checkpoint). This example uses Meta Llama 3, but it works with Llama 2 as well if you replace the model identifier.
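
A minimal sketch of this loading pattern outside TorchServe, assuming `transformers` and `accelerate` are installed (the handler in this example does the same thing, with an added quantization config):

```python
import torch
from transformers import AutoModelForCausalLM

# With low_cpu_mem_usage=True, transformers instantiates the model on the meta
# device (empty weights) and then streams the checkpoint shards into it, so
# peak host memory stays near the size of one shard rather than the full model.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-70B-Instruct",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map="balanced",  # let accelerate spread the layers across available GPUs
)
```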

### Step 1: Get model download permission

Follow [this instruction](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) to get permission
Follow [the instructions](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) to get permission

Log in with a Hugging Face account
```
@@ -14,44 +14,44 @@ huggingface-cli login --token $HUGGINGFACE_TOKEN
```

```bash
python ../Download_model.py --model_path model --model_name meta-llama/Llama-2-70b-chat-hf
python ../Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3-70B-Instruct
```
Model will be saved in the following path, `model/models--meta-llama--Llama-2-70b-chat-hf`.
The model will be saved under `model/models--meta-llama--Meta-Llama-3-70B-Instruct`.

### Step 2: Generate MAR file

Add the downloaded path to `model_path:` in `model-config.yaml` and run the following. A sketch of the resulting handler section is shown below the command.

```bash
torch-model-archiver --model-name llama2-70b-chat --version 1.0 --handler custom_handler.py --config-file model-config.yaml -r requirements.txt --archive-format no-archive
torch-model-archiver --model-name llama3-70b-instruct --version 1.0 --handler custom_handler.py --config-file model-config.yaml -r requirements.txt --archive-format no-archive
```
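
For reference, the `handler` section of `model-config.yaml` ends up looking roughly like this (the snapshot hash comes from your local download and will differ; treat the values as illustrative):

```yaml
handler:
    model_name: "meta-llama/Meta-Llama-3-70B-Instruct"
    model_path: "model/models--meta-llama--Meta-Llama-3-70B-Instruct/snapshots/<snapshot-hash>"
    max_length: 50
    max_new_tokens: 50
    manual_seed: 40
```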

If you are using conda, and notice issues with mpi4py, you would need to install openmpi-mpicc using the following
If you are using conda and notice issues with mpi4py, you can install it with

```
conda install -c conda-forge openmpi-mpicc
conda install mpi4py
```

### Step 3: Add the MAR file to the model store

```bash
mkdir model_store
mv llama2-70b-chat model_store
mv model model_store/llama2-70b-chat
mv llama3-70b-instruct model_store
mv model model_store/llama3-70b-instruct
```

### Step 4: Start torchserve

Update config.properties and start torchserve
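
The exact `config.properties` shipped with the example may differ; for a large-model deployment it typically just sets the endpoint addresses and lets TorchServe install per-model Python dependencies. An illustrative sketch:

```
inference_address=http://127.0.0.1:8080
management_address=http://127.0.0.1:8081
metrics_address=http://127.0.0.1:8082
install_py_dep_per_model=true
```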

```bash
torchserve --start --ncs --ts-config config.properties --model-store model_store --models llama2-70b-chat
torchserve --start --ncs --ts-config config.properties --model-store model_store --models llama3-70b-instruct
```

### Step 5: Run inference

```bash
curl -v "http://localhost:8080/predictions/llama2-70b-chat" -T sample_text.txt
curl -v "http://localhost:8080/predictions/llama3-70b-instruct" -T sample_text.txt
```
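
The same request can also be sent from Python, for example with the `requests` package (a small sketch, not part of the example files):

```python
import requests

# Send the prompt file as the request body, mirroring the curl call above.
with open("sample_text.txt", "rb") as f:
    response = requests.post(
        "http://localhost:8080/predictions/llama3-70b-instruct", data=f
    )
print(response.text)
```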

The curl command results in the following output
@@ -1,9 +1,10 @@
import logging
from abc import ABC
from typing import Dict

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from ts.context import Context
from ts.torch_handler.base_handler import BaseHandler
@@ -39,26 +40,30 @@ def initialize(self, ctx: Context):
seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
torch.manual_seed(seed)

logger.info("Model %s loading tokenizer", ctx.model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.padding_side = "left"
logger.info("Model %s loaded tokenizer successfully", ctx.model_name)

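# Heuristic: Llama 3 tokenizers use a ~128k-token vocabulary, so those models get
# the more aggressive 4-bit NF4 quantization; smaller-vocab models (e.g. Llama 2) load in 8-bit.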
if self.tokenizer.vocab_size >= 128000:
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
else:
quant_config = BitsAndBytesConfig(load_in_8bit=True)

self.model = AutoModelForCausalLM.from_pretrained(
model_path,
device_map="balanced",
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
load_in_8bit=True,
quantization_config=quant_config,
trust_remote_code=True,
)
if ctx.model_yaml_config["handler"]["fast_kernels"]:
from optimum.bettertransformer import BetterTransformer

try:
self.model = BetterTransformer.transform(self.model)
except RuntimeError as error:
logger.warning(
"HuggingFace Optimum is not supporting this model,for the list of supported models, please refer to this doc,https://huggingface.co/docs/optimum/bettertransformer/overview"
)
self.tokenizer = AutoTokenizer.from_pretrained(model_path)

self.device = next(iter(self.model.parameters())).device
logger.info("Model %s loaded successfully", ctx.model_name)
self.initialized = True

@@ -72,38 +77,31 @@ def preprocess(self, requests):
tuple: A tuple with two tensors: the batch of input ids and the batch of
attention masks.
"""
input_texts = [data.get("data") or data.get("body") for data in requests]
input_ids_batch, attention_mask_batch = [], []
for input_text in input_texts:
input_ids, attention_mask = self.encode_input_text(input_text)
input_ids_batch.append(input_ids)
attention_mask_batch.append(attention_mask)
input_ids_batch = torch.cat(input_ids_batch, dim=0).to(self.model.device)
attention_mask_batch = torch.cat(attention_mask_batch, dim=0).to(self.device)
return input_ids_batch, attention_mask_batch

def encode_input_text(self, input_text):
input_texts = [self.preprocess_requests(r) for r in requests]

logger.info("Received texts: '%s'", input_texts)
inputs = self.tokenizer(
input_texts,
max_length=self.max_length,
padding=True,
add_special_tokens=True,
return_tensors="pt",
truncation=True,
).to(self.device)
return inputs

def preprocess_requests(self, request: Dict):
"""
Encodes a single input text using the tokenizer.
Preprocess request
Args:
input_text (str): The input text to be encoded.
request (Dict): Request to be decoded.
Returns:
tuple: A tuple with two tensors: the encoded input ids and the attention mask.
str: Decoded input text
"""
input_text = request.get("data") or request.get("body")
if isinstance(input_text, (bytes, bytearray)):
input_text = input_text.decode("utf-8")
logger.info("Received text: '%s'", input_text)
inputs = self.tokenizer.encode_plus(
input_text,
max_length=self.max_length,
padding=False,
add_special_tokens=True,
return_tensors="pt",
truncation=True,
)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
return input_ids, attention_mask
return input_text

def inference(self, input_batch):
"""
@@ -115,11 +113,8 @@ def inference(self, input_batch):
Returns:
list: A list of strings with the predicted values for each input text in the batch.
"""
input_ids_batch, attention_mask_batch = input_batch
input_ids_batch = input_ids_batch.to(self.device)
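# input_batch is the tokenized BatchEncoding returned by preprocess; unpacking it passes input_ids and attention_mask directly to generate().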
outputs = self.model.generate(
input_ids_batch,
attention_mask=attention_mask_batch,
**input_batch,
max_length=self.max_new_tokens,
)

@@ -6,9 +6,8 @@ responseTimeout: 1200
deviceType: "gpu"

handler:
model_name: "meta-llama/Llama-2-70b-chat-hf"
model_path: "model/models--meta-llama--Llama-2-70b-chat-hf/snapshots/9ff8b00464fc439a64bb374769dec3dd627be1c2"
model_name: "meta-llama/Meta-Llama-3-70B-Instruct"
model_path: "model/models--meta-llama--Meta-Llama-3-70B-Instruct/snapshots/5fcb2901844dde3111159f24205b71c25900ffbd"
max_length: 50
max_new_tokens: 50
manual_seed: 40
fast_kernels: True
@@ -0,0 +1 @@
what is the recipe of mayonnaise?
146 changes: 0 additions & 146 deletions examples/large_models/Huggingface_accelerate/llama2/custom_handler.py

This file was deleted.

This file was deleted.

1 change: 1 addition & 0 deletions ts_scripts/spellcheck_conf/wordlist.txt
@@ -1224,3 +1224,4 @@ Fickling
TorchServer
VirusTotal
untrusted
mpi
