From 6ae146dc36cee02faec3e83ac9ff0725ad5022da Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Wed, 24 Apr 2024 09:58:57 -0700
Subject: [PATCH] Exchange Llama2 against Llama3 in HuggingFace_accelerate
 example (#3108)

* Rename llama2 to llama in HF_accelerate

* Replace llama2 with llama3 in large_models/Huggingface_accelerate

* Remove code handler

* Fix lint error
---
 .../Huggingface_accelerate/Download_model.py |   1 +
 .../{llama2 => llama}/Readme.md              |  24 +--
 .../{llama2 => llama}/config.properties      |   0
 .../custom_handler.py}                       |  83 +++++-----
 .../{llama2 => llama}/model-config.yaml      |   5 +-
 .../{llama2 => llama}/requirements.txt       |   0
 .../llama/sample_text.txt                    |   1 +
 .../llama2/custom_handler.py                 | 146 ------------------
 .../llama2/sample_text.txt                   |   1 -
 ts_scripts/spellcheck_conf/wordlist.txt      |   1 +
 10 files changed, 56 insertions(+), 206 deletions(-)
 rename examples/large_models/Huggingface_accelerate/{llama2 => llama}/Readme.md (52%)
 rename examples/large_models/Huggingface_accelerate/{llama2 => llama}/config.properties (100%)
 rename examples/large_models/Huggingface_accelerate/{llama2/custom_handler_code.py => llama/custom_handler.py} (65%)
 rename examples/large_models/Huggingface_accelerate/{llama2 => llama}/model-config.yaml (50%)
 rename examples/large_models/Huggingface_accelerate/{llama2 => llama}/requirements.txt (100%)
 create mode 100644 examples/large_models/Huggingface_accelerate/llama/sample_text.txt
 delete mode 100644 examples/large_models/Huggingface_accelerate/llama2/custom_handler.py
 delete mode 100644 examples/large_models/Huggingface_accelerate/llama2/sample_text.txt

diff --git a/examples/large_models/Huggingface_accelerate/Download_model.py b/examples/large_models/Huggingface_accelerate/Download_model.py
index ea854abc90..6e9872178d 100644
--- a/examples/large_models/Huggingface_accelerate/Download_model.py
+++ b/examples/large_models/Huggingface_accelerate/Download_model.py
@@ -47,5 +47,6 @@ def hf_model(model_str):
         revision=args.revision,
         cache_dir=args.model_path,
         use_auth_token=True,
+        ignore_patterns=["original/*"],
     )
     print(f"Files for '{args.model_name}' is downloaded to '{snapshot_path}'")
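For orientation, the `ignore_patterns=["original/*"]` argument added above tells `huggingface_hub.snapshot_download` to skip the repository's `original/` folder (the consolidated, non-HF-format checkpoint), so only the sharded HF-format weights are fetched. A minimal standalone sketch of the same call; the repo id and cache directory mirror the example and should be adjusted as needed:

```python
# Minimal sketch of the download performed by Download_model.py:
# skip "original/*" so only the sharded HF-format weights are fetched.
from huggingface_hub import snapshot_download

snapshot_path = snapshot_download(
    repo_id="meta-llama/Meta-Llama-3-70B-Instruct",
    cache_dir="model",               # same cache layout the example expects
    use_auth_token=True,             # gated repo: requires an accepted license
    ignore_patterns=["original/*"],  # consolidated Meta-format checkpoint is skipped
)
print(snapshot_path)
```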
diff --git a/examples/large_models/Huggingface_accelerate/llama2/Readme.md b/examples/large_models/Huggingface_accelerate/llama/Readme.md
similarity index 52%
rename from examples/large_models/Huggingface_accelerate/llama2/Readme.md
rename to examples/large_models/Huggingface_accelerate/llama/Readme.md
index 8151b4a941..41941b1175 100644
--- a/examples/large_models/Huggingface_accelerate/llama2/Readme.md
+++ b/examples/large_models/Huggingface_accelerate/llama/Readme.md
@@ -1,10 +1,10 @@
-# Loading meta-llama/Llama-2-70b-chat-hf on AWS EC2 g5.24xlarge using accelerate
+# Loading meta-llama/Meta-Llama-3-70B-Instruct on AWS EC2 g5.24xlarge using accelerate
 
-This document briefs on serving large HG models with limited resource using accelerate. This option can be activated with `low_cpu_mem_usage=True`. The model is first created on the Meta device (with empty weights) and the state dict is then loaded inside it (shard by shard in the case of a sharded checkpoint).
+This document describes how to serve large HF models with limited resources using accelerate. This option can be activated with `low_cpu_mem_usage=True`. The model is first created on the Meta device (with empty weights) and the state dict is then loaded inside it (shard by shard in the case of a sharded checkpoint).
+This example uses Meta Llama 3, but it also works with Llama 2 if you replace the model identifier.
 
 ### Step 1: Download model Permission
 
-Follow [this instruction](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) to get permission
+Follow [these instructions](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) to get permission
 
 Login with a Hugging Face account
 ```
@@ -14,30 +14,30 @@ huggingface-cli login --token $HUGGINGFACE_TOKEN
 ```
 
 ```bash
-python ../Download_model.py --model_path model --model_name meta-llama/Llama-2-70b-chat-hf
+python ../Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3-70B-Instruct
 ```
-Model will be saved in the following path, `model/models--meta-llama--Llama-2-70b-chat-hf`.
+The model will be saved in the following path: `model/models--meta-llama--Meta-Llama-3-70B-Instruct`.
 
 ### Step 2: Generate MAR file
 
 Add the downloaded path to " model_path:" in `model-config.yaml` and run the following.
 
 ```bash
-torch-model-archiver --model-name llama2-70b-chat --version 1.0 --handler custom_handler.py --config-file model-config.yaml -r requirements.txt --archive-format no-archive
+torch-model-archiver --model-name llama3-70b-instruct --version 1.0 --handler custom_handler.py --config-file model-config.yaml -r requirements.txt --archive-format no-archive
 ```
 
-If you are using conda, and notice issues with mpi4py, you would need to install openmpi-mpicc using the following
+If you are using conda and notice issues with mpi4py, you can install it with
 
 ```
-conda install -c conda-forge openmpi-mpicc
+conda install mpi4py
 ```
 
 ### Step 3: Add the mar file to model store
 
 ```bash
 mkdir model_store
-mv llama2-70b-chat model_store
-mv model model_store/llama2-70b-chat
+mv llama3-70b-instruct model_store
+mv model model_store/llama3-70b-instruct
 ```
 
 ### Step 3: Start torchserve
@@ -45,13 +45,13 @@ mv model model_store/llama2-70b-chat
 Update config.properties and start torchserve
 
 ```bash
-torchserve --start --ncs --ts-config config.properties --model-store model_store --models llama2-70b-chat
+torchserve --start --ncs --ts-config config.properties --model-store model_store --models llama3-70b-instruct
 ```
 
 ### Step 4: Run inference
 
 ```bash
-curl -v "http://localhost:8080/predictions/llama2-70b-chat" -T sample_text.txt
+curl -v "http://localhost:8080/predictions/llama3-70b-instruct" -T sample_text.txt
 ```
 
 results in the following output
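For reference, the Step 4 `curl` call above can also be issued from Python; a minimal sketch, assuming TorchServe is running locally on the default inference port 8080 and the model is registered as `llama3-70b-instruct`:

```python
# Minimal sketch: POST the prompt file to the TorchServe inference API.
import requests

with open("sample_text.txt", "rb") as f:
    resp = requests.post(
        "http://localhost:8080/predictions/llama3-70b-instruct",
        data=f.read(),
    )
resp.raise_for_status()
print(resp.text)  # generated text returned by the handler
```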
diff --git a/examples/large_models/Huggingface_accelerate/llama2/config.properties b/examples/large_models/Huggingface_accelerate/llama/config.properties
similarity index 100%
rename from examples/large_models/Huggingface_accelerate/llama2/config.properties
rename to examples/large_models/Huggingface_accelerate/llama/config.properties
diff --git a/examples/large_models/Huggingface_accelerate/llama2/custom_handler_code.py b/examples/large_models/Huggingface_accelerate/llama/custom_handler.py
similarity index 65%
rename from examples/large_models/Huggingface_accelerate/llama2/custom_handler_code.py
rename to examples/large_models/Huggingface_accelerate/llama/custom_handler.py
index d48c0cc593..be061ac70c 100644
--- a/examples/large_models/Huggingface_accelerate/llama2/custom_handler_code.py
+++ b/examples/large_models/Huggingface_accelerate/llama/custom_handler.py
@@ -1,9 +1,10 @@
 import logging
 from abc import ABC
+from typing import Dict
 
 import torch
 import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
 from ts.context import Context
 from ts.torch_handler.base_handler import BaseHandler
@@ -39,26 +40,30 @@ def initialize(self, ctx: Context):
         seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
         torch.manual_seed(seed)
 
-        logger.info("Model %s loading tokenizer", ctx.model_name)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.padding_side = "left"
+        logger.info("Model %s loaded tokenizer successfully", ctx.model_name)
+
+        if self.tokenizer.vocab_size >= 128000:
+            quant_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.bfloat16,
+            )
+        else:
+            quant_config = BitsAndBytesConfig(load_in_8bit=True)
+
         self.model = AutoModelForCausalLM.from_pretrained(
             model_path,
             device_map="balanced",
             low_cpu_mem_usage=True,
             torch_dtype=torch.float16,
-            load_in_8bit=True,
+            quantization_config=quant_config,
             trust_remote_code=True,
         )
-        if ctx.model_yaml_config["handler"]["fast_kernels"]:
-            from optimum.bettertransformer import BetterTransformer
-
-            try:
-                self.model = BetterTransformer.transform(self.model)
-            except RuntimeError as error:
-                logger.warning(
-                    "HuggingFace Optimum is not supporting this model,for the list of supported models, please refer to this doc,https://huggingface.co/docs/optimum/bettertransformer/overview"
-                )
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-
+        self.device = next(iter(self.model.parameters())).device
         logger.info("Model %s loaded successfully", ctx.model_name)
 
         self.initialized = True
@@ -72,38 +77,31 @@ def preprocess(self, requests):
             tuple: A tuple with two tensors: the batch of input ids and the batch of
                 attention masks.
         """
-        input_texts = [data.get("data") or data.get("body") for data in requests]
-        input_ids_batch, attention_mask_batch = [], []
-        for input_text in input_texts:
-            input_ids, attention_mask = self.encode_input_text(input_text)
-            input_ids_batch.append(input_ids)
-            attention_mask_batch.append(attention_mask)
-        input_ids_batch = torch.cat(input_ids_batch, dim=0).to(self.model.device)
-        attention_mask_batch = torch.cat(attention_mask_batch, dim=0).to(self.device)
-        return input_ids_batch, attention_mask_batch
-
-    def encode_input_text(self, input_text):
+        input_texts = [self.preprocess_requests(r) for r in requests]
+
+        logger.info("Received texts: '%s'", input_texts)
+        inputs = self.tokenizer(
+            input_texts,
+            max_length=self.max_length,
+            padding=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+            truncation=True,
+        ).to(self.device)
+        return inputs
+
+    def preprocess_requests(self, request: Dict):
         """
-        Encodes a single input text using the tokenizer.
+        Preprocess request
         Args:
-            input_text (str): The input text to be encoded.
+            request (Dict): Request to be decoded.
         Returns:
-            tuple: A tuple with two tensors: the encoded input ids and the attention mask.
+            str: Decoded input text
         """
+        input_text = request.get("data") or request.get("body")
         if isinstance(input_text, (bytes, bytearray)):
             input_text = input_text.decode("utf-8")
-        logger.info("Received text: '%s'", input_text)
-        inputs = self.tokenizer.encode_plus(
-            input_text,
-            max_length=self.max_length,
-            padding=False,
-            add_special_tokens=True,
-            return_tensors="pt",
-            truncation=True,
-        )
-        input_ids = inputs["input_ids"]
-        attention_mask = inputs["attention_mask"]
-        return input_ids, attention_mask
+        return input_text
 
     def inference(self, input_batch):
         """
@@ -115,11 +113,8 @@ def inference(self, input_batch):
         Returns:
             list: A list of strings with the predicted values for each input text in the batch.
         """
-        input_ids_batch, attention_mask_batch = input_batch
-        input_ids_batch = input_ids_batch.to(self.device)
         outputs = self.model.generate(
-            input_ids_batch,
-            attention_mask=attention_mask_batch,
+            **input_batch,
             max_length=self.max_new_tokens,
         )
 
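A note on the handler change above: the `vocab_size >= 128000` check is a heuristic to tell the model families apart. Llama 3 ships a vocabulary of roughly 128k tokens while Llama 2 uses about 32k, so Llama 3 gets 4-bit NF4 quantization (to fit the 70B weights on a g5.24xlarge) and Llama 2 keeps the earlier 8-bit path. A standalone sketch of that decision, with the vocabulary sizes treated as assumptions rather than guarantees:

```python
# Illustrative sketch: choose a BitsAndBytesConfig from the tokenizer's vocab size.
# Assumes Llama 3 tokenizers report ~128k tokens and Llama 2 tokenizers ~32k.
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct")

if tokenizer.vocab_size >= 128000:  # treat as Llama 3: quantize to 4-bit NF4
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
else:  # treat as Llama 2: keep the previous 8-bit behaviour
    quant_config = BitsAndBytesConfig(load_in_8bit=True)
```

With `device_map="balanced"` the quantized weights are spread across the instance's GPUs, which is also why the new handler records `self.device` from the first model parameter and moves the tokenized batch there.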
diff --git a/examples/large_models/Huggingface_accelerate/llama2/model-config.yaml b/examples/large_models/Huggingface_accelerate/llama/model-config.yaml
similarity index 50%
rename from examples/large_models/Huggingface_accelerate/llama2/model-config.yaml
rename to examples/large_models/Huggingface_accelerate/llama/model-config.yaml
index 2e7e950e43..03e2a3b2b6 100644
--- a/examples/large_models/Huggingface_accelerate/llama2/model-config.yaml
+++ b/examples/large_models/Huggingface_accelerate/llama/model-config.yaml
@@ -6,9 +6,8 @@ responseTimeout: 1200
 deviceType: "gpu"
 
 handler:
-    model_name: "meta-llama/Llama-2-70b-chat-hf"
-    model_path: "model/models--meta-llama--Llama-2-70b-chat-hf/snapshots/9ff8b00464fc439a64bb374769dec3dd627be1c2"
+    model_name: "meta-llama/Meta-Llama-3-70B-Instruct"
+    model_path: "model/models--meta-llama--Meta-Llama-3-70B-Instruct/snapshots/5fcb2901844dde3111159f24205b71c25900ffbd"
     max_length: 50
     max_new_tokens: 50
     manual_seed: 40
-    fast_kernels: True
diff --git a/examples/large_models/Huggingface_accelerate/llama2/requirements.txt b/examples/large_models/Huggingface_accelerate/llama/requirements.txt
similarity index 100%
rename from examples/large_models/Huggingface_accelerate/llama2/requirements.txt
rename to examples/large_models/Huggingface_accelerate/llama/requirements.txt
diff --git a/examples/large_models/Huggingface_accelerate/llama/sample_text.txt b/examples/large_models/Huggingface_accelerate/llama/sample_text.txt
new file mode 100644
index 0000000000..b93f7033ef
--- /dev/null
+++ b/examples/large_models/Huggingface_accelerate/llama/sample_text.txt
@@ -0,0 +1 @@
+what is the recipe of mayonnaise?
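The snapshot hash in `model_path` above is specific to one particular download. After running `Download_model.py`, substitute the snapshot directory it actually created; a small sketch for locating it, assuming the standard Hugging Face cache layout under `model/`:

```python
# Illustrative helper: find the snapshot directory to paste into model-config.yaml.
# Assumes the standard HF cache layout: model/models--<org>--<name>/snapshots/<hash>
from pathlib import Path

snapshots = Path("model/models--meta-llama--Meta-Llama-3-70B-Instruct/snapshots")
candidates = sorted(p for p in snapshots.iterdir() if p.is_dir())
print("model_path:", candidates[-1])  # usually exactly one snapshot
```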
diff --git a/examples/large_models/Huggingface_accelerate/llama2/custom_handler.py b/examples/large_models/Huggingface_accelerate/llama2/custom_handler.py
deleted file mode 100644
index b9b51809cb..0000000000
--- a/examples/large_models/Huggingface_accelerate/llama2/custom_handler.py
+++ /dev/null
@@ -1,146 +0,0 @@
-import logging
-from abc import ABC
-
-import torch
-import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-from ts.context import Context
-from ts.torch_handler.base_handler import BaseHandler
-
-logger = logging.getLogger(__name__)
-logger.info("Transformers version %s", transformers.__version__)
-
-
-class LlamaHandler(BaseHandler, ABC):
-    """
-    Transformers handler class for sequence, token classification and question answering.
-    """
-
-    def __init__(self):
-        super(LlamaHandler, self).__init__()
-        self.max_length = None
-        self.max_new_tokens = None
-        self.tokenizer = None
-        self.initialized = False
-
-    def initialize(self, ctx: Context):
-        """In this initialize function, the HF large model is loaded and
-        partitioned using DeepSpeed.
-        Args:
-            ctx (context): It is a JSON Object containing information
-            pertaining to the model artifacts parameters.
-        """
-        model_dir = ctx.system_properties.get("model_dir")
-        self.max_length = int(ctx.model_yaml_config["handler"]["max_length"])
-        self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"])
-        model_name = ctx.model_yaml_config["handler"]["model_name"]
-        model_path = f'{model_dir}/{ctx.model_yaml_config["handler"]["model_path"]}'
-        seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
-        torch.manual_seed(seed)
-
-        logger.info("Model %s loading tokenizer", ctx.model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            device_map="balanced",
-            low_cpu_mem_usage=True,
-            torch_dtype=torch.float16,
-            load_in_8bit=True,
-            trust_remote_code=True,
-        )
-        if ctx.model_yaml_config["handler"]["fast_kernels"]:
-            from optimum.bettertransformer import BetterTransformer
-
-            try:
-                self.model = BetterTransformer.transform(self.model)
-            except RuntimeError as error:
-                logger.warning(
-                    "HuggingFace Optimum is not supporting this model,for the list of supported models, please refer to this doc,https://huggingface.co/docs/optimum/bettertransformer/overview"
-                )
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-        self.tokenizer.add_special_tokens(
-            {
-                "pad_token": "",
-            }
-        )
-        self.model.resize_token_embeddings(self.model.config.vocab_size + 1)
-
-        logger.info("Model %s loaded successfully", ctx.model_name)
-        self.initialized = True
-
-    def preprocess(self, requests):
-        """
-        Basic text preprocessing, based on the user's choice of application mode.
-        Args:
-            requests (list): A list of dictionaries with a "data" or "body" field, each
-                containing the input text to be processed.
-        Returns:
-            tuple: A tuple with two tensors: the batch of input ids and the batch of
-                attention masks.
-        """
-        input_texts = [data.get("data") or data.get("body") for data in requests]
-        input_ids_batch, attention_mask_batch = [], []
-        for input_text in input_texts:
-            input_ids, attention_mask = self.encode_input_text(input_text)
-            input_ids_batch.append(input_ids)
-            attention_mask_batch.append(attention_mask)
-        input_ids_batch = torch.cat(input_ids_batch, dim=0).to(self.model.device)
-        attention_mask_batch = torch.cat(attention_mask_batch, dim=0).to(self.device)
-        return input_ids_batch, attention_mask_batch
-
-    def encode_input_text(self, input_text):
-        """
-        Encodes a single input text using the tokenizer.
-        Args:
-            input_text (str): The input text to be encoded.
-        Returns:
-            tuple: A tuple with two tensors: the encoded input ids and the attention mask.
-        """
-        if isinstance(input_text, (bytes, bytearray)):
-            input_text = input_text.decode("utf-8")
-        logger.info("Received text: '%s'", input_text)
-        inputs = self.tokenizer.encode_plus(
-            input_text,
-            max_length=self.max_length,
-            padding=True,
-            add_special_tokens=True,
-            return_tensors="pt",
-            truncation=True,
-        )
-        input_ids = inputs["input_ids"]
-        attention_mask = inputs["attention_mask"]
-        return input_ids, attention_mask
-
-    def inference(self, input_batch):
-        """
-        Predicts the class (or classes) of the received text using the serialized transformers
-        checkpoint.
-        Args:
-            input_batch (tuple): A tuple with two tensors: the batch of input ids and the batch
-                of attention masks, as returned by the preprocess function.
-        Returns:
-            list: A list of strings with the predicted values for each input text in the batch.
-        """
-        input_ids_batch, attention_mask_batch = input_batch
-        input_ids_batch = input_ids_batch.to(self.device)
-        outputs = self.model.generate(
-            input_ids_batch,
-            attention_mask=attention_mask_batch,
-            max_length=self.max_new_tokens,
-        )
-
-        inferences = self.tokenizer.batch_decode(
-            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-
-        logger.info("Generated text: %s", inferences)
-        return inferences
-
-    def postprocess(self, inference_output):
-        """Post Process Function converts the predicted response into Torchserve readable format.
-        Args:
-            inference_output (list): It contains the predicted response of the input text.
-        Returns:
-            (list): Returns a list of the Predictions and Explanations.
-        """
-        return inference_output
diff --git a/examples/large_models/Huggingface_accelerate/llama2/sample_text.txt b/examples/large_models/Huggingface_accelerate/llama2/sample_text.txt
deleted file mode 100644
index edfe9f4c10..0000000000
--- a/examples/large_models/Huggingface_accelerate/llama2/sample_text.txt
+++ /dev/null
@@ -1 +0,0 @@
-what is the recipe of mayonnaise?
\ No newline at end of file
diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt
index e4d166d4b9..ce1603f138 100644
--- a/ts_scripts/spellcheck_conf/wordlist.txt
+++ b/ts_scripts/spellcheck_conf/wordlist.txt
@@ -1224,3 +1224,4 @@ Fickling
 TorchServer
 VirusTotal
 untrusted
+mpi
\ No newline at end of file