tgi_env.py
#!/usr/bin/env python
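
# Resolve consistent TGI router/server env vars for Neuron precompiled models:
# parse the launcher arguments, validate them against the model's neuron config
# (or look up a compatible cached entry on the hub), then write the resulting
# values as shell exports to the file pointed at by ENV_FILEPATH.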

import argparse
import logging
import os
import sys
from typing import Any, Dict, List, Optional

from huggingface_hub import constants
from transformers import AutoConfig

from optimum.neuron.modeling_decoder import get_available_cores
from optimum.neuron.utils import get_hub_cached_entries
from optimum.neuron.utils.version_utils import get_neuronxcc_version

logger = logging.getLogger(__name__)

tgi_router_env_vars = ["MAX_BATCH_SIZE", "MAX_TOTAL_TOKENS", "MAX_INPUT_TOKENS", "MAX_BATCH_PREFILL_TOKENS"]
tgi_server_env_vars = ["HF_NUM_CORES", "HF_AUTO_CAST_TYPE"]
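
# Each env var is paired with the neuron config entry it must agree with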
env_config_peering = [
    ("MAX_BATCH_SIZE", "batch_size"),
    ("MAX_TOTAL_TOKENS", "sequence_length"),
    ("HF_AUTO_CAST_TYPE", "auto_cast_type"),
    ("HF_NUM_CORES", "num_cores"),
]

# By the end of this script all env vars should be specified properly
env_vars = tgi_server_env_vars + tgi_router_env_vars

available_cores = get_available_cores()
neuronxcc_version = get_neuronxcc_version()
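

# Intercepts the arguments that the TGI launcher passes down; command-line
# values take precedence over the matching env vars.
# Example (hypothetical invocation):
#   python tgi_env.py --model-id my-org/my-neuron-model --max-batch-size 4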
def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    if not argv:
        argv = sys.argv
    # All these are params passed to tgi and intercepted here
    parser.add_argument(
        "--max-input-tokens", type=int, default=os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0))
    )
    parser.add_argument("--max-total-tokens", type=int, default=os.getenv("MAX_TOTAL_TOKENS", 0))
    parser.add_argument("--max-batch-size", type=int, default=os.getenv("MAX_BATCH_SIZE", 0))
    parser.add_argument("--max-batch-prefill-tokens", type=int, default=os.getenv("MAX_BATCH_PREFILL_TOKENS", 0))
    parser.add_argument("--model-id", type=str, default=os.getenv("MODEL_ID"))
    parser.add_argument("--revision", type=str, default=os.getenv("REVISION"))

    args = parser.parse_known_args(argv)[0]

    if not args.model_id:
        raise Exception(
            "No model id provided! Either specify it using the --model-id cmdline or the MODEL_ID env var"
        )

    # Override env with cmdline params
    os.environ["MODEL_ID"] = args.model_id

    # Set all tgi router and tgi server values to consistent values as early as possible:
    # given the order of the parser defaults, the tgi router values can override the tgi server ones
    if args.max_total_tokens > 0:
        os.environ["MAX_TOTAL_TOKENS"] = str(args.max_total_tokens)
    if args.max_input_tokens > 0:
        os.environ["MAX_INPUT_TOKENS"] = str(args.max_input_tokens)
    if args.max_batch_size > 0:
        os.environ["MAX_BATCH_SIZE"] = str(args.max_batch_size)
    if args.max_batch_prefill_tokens > 0:
        os.environ["MAX_BATCH_PREFILL_TOKENS"] = str(args.max_batch_prefill_tokens)
    if args.revision:
        os.environ["REVISION"] = str(args.revision)

    return args
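

# Writes the resolved values as shell exports to the file pointed at by
# ENV_FILEPATH, e.g. (illustrative output):
#   export MAX_BATCH_SIZE=4
#   export MAX_TOTAL_TOKENS=4096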
def neuron_config_to_env(neuron_config):
    with open(os.environ["ENV_FILEPATH"], "w") as f:
        for env_var, config_key in env_config_peering:
            f.write("export {}={}\n".format(env_var, neuron_config[config_key]))
        max_input_tokens = os.getenv("MAX_INPUT_TOKENS")
        if not max_input_tokens:
            max_input_tokens = int(neuron_config["sequence_length"]) // 2
            if max_input_tokens == 0:
                raise Exception("Model sequence length should be greater than 1")
        f.write("export MAX_INPUT_TOKENS={}\n".format(max_input_tokens))
        max_batch_prefill_tokens = os.getenv("MAX_BATCH_PREFILL_TOKENS")
        if not max_batch_prefill_tokens:
            max_batch_prefill_tokens = int(neuron_config["batch_size"]) * int(max_input_tokens)
        f.write("export MAX_BATCH_PREFILL_TOKENS={}\n".format(max_batch_prefill_tokens))
def sort_neuron_configs(dictionary):
    return -dictionary["num_cores"], -dictionary["batch_size"]


def lookup_compatible_cached_model(model_id: str, revision: Optional[str]) -> Optional[Dict[str, Any]]:
    # Reuse the same mechanism as the one used to configure the tgi server part
    # The only difference here is that we stay as flexible as possible on the compatibility part
    entries = get_hub_cached_entries(model_id, "inference")
    logger.debug("Found %d cached entries for model %s, revision %s", len(entries), model_id, revision)
    all_compatible = []
    for entry in entries:
        if check_env_and_neuron_config_compatibility(entry, check_compiler_version=True):
            all_compatible.append(entry)
    if not all_compatible:
        logger.debug(
            "No compatible cached entry found for model %s, env %s, available cores %s, neuronxcc version %s",
            model_id,
            get_env_dict(),
            available_cores,
            neuronxcc_version,
        )
        return None

    logger.info("%d compatible neuron cached models found", len(all_compatible))
    all_compatible = sorted(all_compatible, key=sort_neuron_configs)
    entry = all_compatible[0]
    return entry
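

# A neuron config is compatible when the local host has enough Neuron cores,
# the compiler versions match (when requested) and every already-set env var
# agrees with the corresponding neuron config entry.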
def check_env_and_neuron_config_compatibility(neuron_config: Dict[str, Any], check_compiler_version: bool) -> bool:
    logger.debug(
        "Checking that the provided neuron config %s is compatible with the local setup and provided environment",
        neuron_config,
    )

    # Local setup compat checks
    if neuron_config["num_cores"] > available_cores:
        logger.debug("Not enough neuron cores available to run the provided neuron config")
        return False

    if check_compiler_version and neuron_config["compiler_version"] != neuronxcc_version:
        logger.debug(
            "Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)",
            neuronxcc_version,
            neuron_config["compiler_version"],
        )
        return False

    for env_var, config_key in env_config_peering:
        neuron_config_value = str(neuron_config[config_key])
        env_value = os.getenv(env_var, str(neuron_config_value))
        if env_value != neuron_config_value:
            logger.debug(
                "The provided env var '%s' and the neuron config '%s' param differ (%s != %s)",
                env_var,
                config_key,
                env_value,
                neuron_config_value,
            )
            return False

    max_input_tokens = int(os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0)))
    if max_input_tokens > 0:
        sequence_length = neuron_config["sequence_length"]
        if max_input_tokens >= sequence_length:
            logger.debug(
                "Specified max input tokens is not compatible with config sequence length (%s >= %s)",
                max_input_tokens,
                sequence_length,
            )
            return False

    return True


def get_env_dict() -> Dict[str, str]:
    d = {}
    for k in env_vars:
        d[k] = os.getenv(k)
    return d


def main():
    """
    Determine the proper default TGI env variables for Neuron precompiled
    models and write them out as shell exports.
    """
    args = parse_cmdline_and_set_env()
    for env_var in env_vars:
        if not os.getenv(env_var):
            break
    else:
        logger.info("All env vars %s already set, skipping, the user knows what they are doing", env_vars)
        sys.exit(0)

    cache_dir = constants.HF_HUB_CACHE
    logger.info("Cache dir %s, model %s", cache_dir, args.model_id)

    config = AutoConfig.from_pretrained(args.model_id, revision=args.revision)
    neuron_config = getattr(config, "neuron", None)
    if neuron_config is not None:
        compatible = check_env_and_neuron_config_compatibility(neuron_config, check_compiler_version=False)
        if not compatible:
            env_dict = get_env_dict()
            msg = "Invalid neuron config and env. Config {}, env {}, available cores {}, neuronxcc version {}".format(
                neuron_config, env_dict, available_cores, neuronxcc_version
            )
            logger.error(msg)
            raise Exception(msg)
    else:
        neuron_config = lookup_compatible_cached_model(args.model_id, args.revision)

    if not neuron_config:
        msg = "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}".format(
            get_env_dict(), available_cores, neuronxcc_version
        )
        logger.error(msg)
        raise Exception(msg)

    neuron_config_to_env(neuron_config)


if __name__ == "__main__":
    main()
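
# Typical usage (hypothetical): a container entrypoint runs this script with
# ENV_FILEPATH pointing at a temp file, then sources it before launching TGI:
#   ENV_FILEPATH=/tmp/tgi_env python tgi_env.py --model-id my-org/my-neuron-model
#   . /tmp/tgi_env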