
Commit 2f1f52f

HamidShojanazeri authored
Pippy deferred init (#2310)
* remove HF auth
* update steps
* add model checkpoint path
* add deferred init
* fix keys
* clean up
* adding torchpippy
* add comment for checkpoint path
* add checks for configs
* fixing thread numbers
* fixing max_new_tokens
* adding max_new_tokens
* fix padding
* revert tokenizer changes
* fixing the response size
* making index file optional
* fixing new tokens
* fixing the output issue
* add check for torch version
* fixing the index file path
* extend the word list
* moving the script to parent directory
* change the path to download script
* moving to utils
* adding utils
* allowing only related patterns
* setting default chunks to 1

Co-authored-by: Ubuntu <ubuntu@ip-172-31-9-21.us-west-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-5-255.us-west-2.compute.internal>
1 parent 614bfc0 commit 2f1f52f

File tree

* examples/large_models/Huggingface_pippy/Readme.md
* examples/large_models/Huggingface_pippy/model-config.yaml
* examples/large_models/Huggingface_pippy/pippy_handler.py
* examples/large_models/Huggingface_pippy/requirements.txt
* examples/large_models/utils/Download_model.py (renamed from examples/large_models/Huggingface_pippy/Download_model.py)
* ts/handler_utils/distributed/pt_pippy.py
* ts_scripts/spellcheck_conf/wordlist.txt

7 files changed: +121 −38 lines
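The core of this change is deferred initialization: the handler now builds the model under `torch.device("meta")`, so parameters exist only as shape/dtype metadata and no weight memory is allocated up front; PiPPy later materializes each pipeline stage from the sharded checkpoint referenced by `model_path` and `index_filename`. A condensed sketch of that pattern, mirroring the handler diff below (the snapshot path is the one printed in Step 1 of the Readme and is purely illustrative):

```python
# Sketch of deferred (meta-device) initialization as used in pippy_handler.py below.
# The snapshot path is illustrative; substitute your own checkpoint directory.
import torch
from transformers import AutoModelForCausalLM

model_path = "model/models--facebook--opt-30b/snapshots/ceea0a90ac0f6fae7c2c34bcb40477438c152546"

with torch.device("meta"):
    # Parameters are created as meta tensors: shapes and dtypes only, no weight storage.
    model = AutoModelForCausalLM.from_pretrained(
        model_path, use_cache=False, torch_dtype=torch.float16
    )

# Real weights are loaded later, stage by stage, when get_pipeline_driver() builds the
# PiPPy pipeline from the checkpoint index file (pytorch_model.bin.index.json).
```

This is also why the handler now logs a warning on PyTorch versions below 2.0.0, where the meta-device context manager is not available.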

examples/large_models/Huggingface_pippy/Readme.md

Lines changed: 19 additions & 12 deletions
@@ -6,23 +6,25 @@ PiPPy provides pipeline parallelism for serving large models that would not fit
 
 ## How to serve your large HuggingFace models with PiPPy in Torchserve?
 
-We use a Torchserve custom handler that inherits from base_pippy_handler to load the model and define our logic for preprocess, inference and post processing. This is basically very similar to your evaluation process.
+We use a Torchserve custom handler that inherits from base_pippy_handler to load the model and define our logic for preprocess, inference and post processing. This is basically very similar to your evaluation process. The following settings have been tested on a g5.12xlarge EC2 instance, which has 4x A10 GPUs.
 
-### Step 1: Download model
+To run this example we need torchpippy installed. It has been added to requirements.txt, which can be bundled during model packaging.
 
-Login into huggingface hub with token by running the below command
+Generally, to install torchpippy you can run the following:
 
 ```bash
-huggingface-cli login
+pip install torchpippy
+
 ```
-paste the token generated from huggingface hub.
+
+### Step 1: Download model
 
 ```bash
-python Download_model.py --model_name facebook/opt-6.7b
+python ../utils/Download_model.py --model_name facebook/opt-30b
 ```
 The script prints the path where the model is downloaded as below. This is an example and in your workload you want to use your actual trained model checkpoints.
 
-`model/models--bigscience-bloom-7b1/snapshots/5546055f03398095e385d7dc625e636cc8910bf2/`
+`model/models--facebook--opt-30b/snapshots/ceea0a90ac0f6fae7c2c34bcb40477438c152546/`
 
 The downloaded model is around 14GB.
@@ -46,37 +48,42 @@ pippy:
     input_names: ['input_ids'] # input arg names to the model, this is required for FX tracing
     model_type: "HF" # set the model type to HF if you are using Huggingface model other wise leave it blank or any other model you use.
     rpc_timeout: 1800
+    num_worker_threads: 512 # number of threads for the RPC worker; 512 is usually a good number
 
 handler:
     max_length: 80 # max length of tokens for tokenizer in the handler
+    model_name: "/home/ubuntu/serve/examples/large_models/Huggingface_pippy/model/models--facebook--opt-30b/snapshots/ceea0a90ac0f6fae7c2c34bcb40477438c152546" # path to the checkpoints, here the downloaded files. Please change it to your model path.
+    index_file_name: 'pytorch_model.bin.index.json' # index json file name in the model checkpoint folder that keeps information about the distributed checkpoints
+    manual_seed: 40
+    dtype: fp16 # data type to load your model checkpoint; supported: fp32, fp16, bf16
 ```
 
 ### Step 3: Generate Tar/ MAR file
 
-Navigate up to `Huggingface_Largemodels` directory.
+Navigate up to the `largemodels` directory. Because bundling the large model checkpoints is very time consuming, we pass the model checkpoint path in model-config.yaml as shown above. This makes packaging very fast; for production settings, the large models can be kept in a shared location and referenced from there in the model-config.
 
 ```bash
-torch-model-archiver --model-name bloom --version 1.0 --handler pippy_handler.py --extra-files model/models--facebook--opt-iml-max-1.3b/snapshots/d60fa58f50def19751da2075791da359ca19d273 -r requirements.txt --config-file model-config.yaml --archive-format tgz
+torch-model-archiver --model-name opt --version 1.0 --handler pippy_handler.py -r requirements.txt --config-file model-config.yaml --archive-format tgz
 
 ```
 
 ### Step 4: Add the mar file to model store
 
 ```bash
 mkdir model_store
-mv bloom.mar model_store
+mv opt.tar.gz model_store
 ```
 
 ### Step 5: Start torchserve
 
 Update config.properties and start torchserve
 
 ```bash
-torchserve --ncs --start --model-store model_store --models bloom.mar
+torchserve --ncs --start --model-store model_store --models opt.tar.gz
 ```
 
 ### Step 6: Run inference
 
 ```bash
-curl -v "http://localhost:8080/predictions/bloom" -T sample_text.txt
+curl -v "http://localhost:8080/predictions/opt" -T sample_text.txt
 ```
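The Step 6 curl call can also be issued from Python; a minimal sketch, not part of this commit (the `opt` endpoint name and the `sample_text.txt` input file come from the steps above):

```python
# Rough Python equivalent of the Step 6 curl request (sketch only).
import requests

with open("sample_text.txt", "rb") as f:
    response = requests.post("http://localhost:8080/predictions/opt", data=f)

print(response.status_code)
print(response.text)  # text generated by the pippy handler
```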

examples/large_models/Huggingface_pippy/model-config.yaml

Lines changed: 7 additions & 3 deletions
@@ -1,8 +1,8 @@
 #frontend settings
 minWorkers: 1
 maxWorkers: 1
-maxBatchDelay: 100
-responseTimeout: 120
+maxBatchDelay: 200
+responseTimeout: 300
 parallelType: "pp"
 deviceType: "gpu"
 torchrun:
@@ -14,8 +14,12 @@ pippy:
     model_type: "HF"
     chunks: 1
     input_names: ["input_ids"]
-    num_worker_threads: 512
+    num_worker_threads: 128
 
 handler:
+    model_path: "/home/ubuntu/serve/examples/large_models/Huggingface_pippy/model/models--facebook--opt-30b/snapshots/ceea0a90ac0f6fae7c2c34bcb40477438c152546"
+    index_filename: 'pytorch_model.bin.index.json'
     max_length: 50
+    max_new_tokens: 60
     manual_seed: 40
+    dtype: fp16

examples/large_models/Huggingface_pippy/pippy_handler.py

Lines changed: 46 additions & 20 deletions
@@ -2,6 +2,7 @@
 import time
 from abc import ABC
 
+import packaging.version
 import requests
 import torch
 import transformers
@@ -12,6 +13,12 @@
 
 logger = logging.getLogger(__name__)
 logger.info("Transformers version %s", transformers.__version__)
+if packaging.version.parse(torch.__version__) >= packaging.version.parse("2.0.0"):
+    logger.info("PyTorch version is 2.0.0 or greater")
+else:
+    logger.info(
+        "PyTorch version is less than 2.0.0, initializing with meta device needs PyTorch 2.0.0 and greater"
+    )
 
 
 class TransformersSeqClassifierHandler(BasePippyHandler, ABC):
@@ -36,18 +43,43 @@ def initialize(self, ctx):
         model_dir = properties.get("model_dir")
         self.device = self.local_rank
 
+        model_path = ctx.model_yaml_config["handler"]["model_path"]
         seed = ctx.model_yaml_config["handler"]["manual_seed"]
+        dtype_str = ctx.model_yaml_config["handler"]["dtype"]
         torch.manual_seed(seed)
 
-        self.model = AutoModelForCausalLM.from_pretrained(model_dir, use_cache=False)
+        dtypes = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
 
-        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, return_tensors="pt")
+        dtype = dtypes.get(dtype_str, torch.float32)
+        if dtype_str not in dtypes:
+            # Unknown dtype string: warn and fall back to fp32.
+            logger.info(
+                f"Unsupported data type {dtype_str}, "
+                "please submit a PR to support it. Falling back to fp32 now."
+            )
+
+        skip_init_start = time.perf_counter()
+        with torch.device("meta"):
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_path, use_cache=False, torch_dtype=dtype
+            )
+        skip_init_end = time.perf_counter()
+        logger.info(
+            f"init model time on meta device took {skip_init_end - skip_init_start} seconds"
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, return_tensors="pt")
+        self.tokenizer.pad_token = self.tokenizer.eos_token
 
         self.max_length = ctx.model_yaml_config["handler"]["max_length"]
+        self.max_new_tokens = ctx.model_yaml_config["handler"]["max_new_tokens"]
 
         logger.info("Instantiating model Pipeline")
-        model_init_start = time.time()
+        pippy_compile_time_start = time.perf_counter()
         self.model = get_pipeline_driver(self.model, self.world_size, ctx)
+        pippy_compile_time_end = time.perf_counter()
+
+        logger.info(
+            f"pippy compile time took {pippy_compile_time_end - pippy_compile_time_start} seconds on rank {self.local_rank}"
+        )
 
         logger.info("Transformer model from path %s loaded successfully", model_dir)
 
@@ -64,14 +96,12 @@ def preprocess(self, requests):
         attention masks.
         """
         input_texts = [data.get("data") or data.get("body") for data in requests]
-        input_ids_batch, attention_mask_batch = [], []
+        input_ids_batch = []
         for input_text in input_texts:
-            input_ids, attention_mask = self.encode_input_text(input_text)
+            input_ids = self.encode_input_text(input_text)
             input_ids_batch.append(input_ids)
-            attention_mask_batch.append(attention_mask)
         input_ids_batch = torch.cat(input_ids_batch, dim=0).to(self.device)
-        attention_mask_batch = torch.cat(attention_mask_batch, dim=0).to(self.device)
-        return input_ids_batch, attention_mask_batch
+        return input_ids_batch
 
     def encode_input_text(self, input_text):
         """
@@ -92,8 +122,7 @@ def encode_input_text(self, input_text):
             return_tensors="pt",
         )
         input_ids = inputs["input_ids"]
-        attention_mask = inputs["attention_mask"]
-        return input_ids, attention_mask
+        return input_ids
 
     def inference(self, input_batch):
         """
@@ -105,21 +134,18 @@ def inference(self, input_batch):
         Returns:
             list: A list of strings with the predicted values for each input text in the batch.
         """
-        input_ids_batch, attention_mask_batch = input_batch
+        input_ids_batch = input_batch
         input_ids_batch = input_ids_batch.to(self.device)
         outputs = self.model.generate(
             input_ids_batch,
-            attention_mask=attention_mask_batch,
-            max_length=30,
+            max_length=self.max_new_tokens,
+        )
+        generated_text = self.tokenizer.batch_decode(
+            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
 
-        inferences = [
-            self.tokenizer.batch_decode(
-                outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
-            )
-        ]
-        logger.info("Generated text: %s", inferences)
-        return inferences
+        logger.info("Generated text: %s", generated_text)
+        return generated_text
 
     def postprocess(self, inference_output):
         """Post Process Function converts the predicted response into Torchserve readable format.
examples/large_models/Huggingface_pippy/requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
 transformers
-
+torchpippy

examples/large_models/Huggingface_pippy/Download_model.py renamed to examples/large_models/utils/Download_model.py

Lines changed: 4 additions & 1 deletion
@@ -41,11 +41,14 @@ def hf_model(model_str):
 )
 parser.add_argument("--revision", "-r", type=str, default="main", help="Revision")
 args = parser.parse_args()
+# Only download pytorch checkpoint files
+allow_patterns = ["*.json", "*.pt", "*.bin", "*.txt", "*.model"]
 
 snapshot_path = snapshot_download(
     repo_id=args.model_name,
     revision=args.revision,
+    allow_patterns=allow_patterns,
     cache_dir=args.model_path,
-    use_auth_token=True,
+    use_auth_token=False,
 )
 print(f"Files for '{args.model_name}' is downloaded to '{snapshot_path}'")

ts/handler_utils/distributed/pt_pippy.py

Lines changed: 42 additions & 1 deletion
@@ -51,10 +51,49 @@ def get_pipeline_driver(model, world_size, ctx):
         torch.nn.Sequential: The pipeline driver for the model.
     """
     # Extract configuration parameters from the context
-    chunks = ctx.model_yaml_config["pippy"]["chunks"]
+
+    # Check that the "pippy" and "handler" keys are present in the YAML config
+    assert "pippy" in ctx.model_yaml_config, "Missing 'pippy' key in YAML config"
+    assert "handler" in ctx.model_yaml_config, "Missing 'handler' key in YAML config"
+
+    # Check that the required keys are present in the "pippy" section
+    assert (
+        "input_names" in ctx.model_yaml_config["pippy"]
+    ), "Missing 'input_names' key in YAML config"
+    assert (
+        "model_type" in ctx.model_yaml_config["pippy"]
+    ), "Missing 'model_type' key in YAML config"
+
+    # Check that the required keys are present in the "handler" section
+    assert (
+        "model_path" in ctx.model_yaml_config["handler"]
+    ), "Missing 'model_path' key in YAML config"
+
+    # Set variables from the config
     input_names = ctx.model_yaml_config["pippy"]["input_names"]
     model_type = ctx.model_yaml_config["pippy"]["model_type"]
+    model_path = ctx.model_yaml_config["handler"]["model_path"]
+    try:
+        chunks = ctx.model_yaml_config["pippy"]["chunks"]
+    except KeyError:
+        chunks = 1
+    try:
+        index_filename = ctx.model_yaml_config["handler"]["index_filename"]
+    except KeyError:
+        index_filename = None
+
+    # Check that the index file exists
+    if index_filename is not None:
+        index_file_path = os.path.join(model_path, index_filename)
+        assert os.path.exists(
+            index_file_path
+        ), f"Index file '{index_file_path}' not found"
+    else:
+        index_file_path = None
 
+    checkpoint_prefix = None
     # Set the model to evaluation mode
     model.eval()
 
@@ -83,6 +122,8 @@ def get_pipeline_driver(model, world_size, ctx):
         split_policy=split_policy,
         tracer=tracer,
         concrete_args=concrete_args,
+        index_filename=index_file_path,
+        checkpoint_prefix=checkpoint_prefix,
     )
 
     # Inject the pipeline forward method if necessary
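The index file passed through to PiPPy (pytorch_model.bin.index.json by default) is the standard HuggingFace sharded-checkpoint index: its weight_map entry maps each parameter name to the .bin shard that stores it, which is what allows deferred initialization to materialize weights without loading the whole model at once. A quick way to inspect it (the snapshot path is illustrative):

```python
# Sketch: inspect the sharded-checkpoint index that index_filename points to.
import json
import os

model_path = "model/models--facebook--opt-30b/snapshots/ceea0a90ac0f6fae7c2c34bcb40477438c152546"  # illustrative
index_file_path = os.path.join(model_path, "pytorch_model.bin.index.json")

with open(index_file_path) as f:
    index = json.load(f)

# "weight_map" maps parameter names to the checkpoint shard that contains them.
for name, shard in list(index["weight_map"].items())[:5]:
    print(name, "->", shard)
```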

ts_scripts/spellcheck_conf/wordlist.txt

Lines changed: 2 additions & 0 deletions
@@ -1043,3 +1043,5 @@ QueueTime
 WorkerLoadTime
 WorkerName
 WorkerThreadTime
+largemodels
+torchpippy
