Huggingface configurations refactoring #1283

Merged · 2 commits · Nov 8, 2023
158 changes: 40 additions & 118 deletions engines/python/setup/djl_python/huggingface.py
@@ -28,6 +28,9 @@
 from djl_python.outputs import Output
 from djl_python.streaming_utils import StreamingUtils
 
+from djl_python.properties_manager.properties import StreamingEnum, is_rolling_batch_enabled, is_streaming_enabled
+from djl_python.properties_manager.hf_properties import HuggingFaceProperties
+
 ARCHITECTURES_2_TASK = {
     "TapasForQuestionAnswering": "table-question-answering",
     "ForQuestionAnswering": "question-answering",
@@ -69,22 +72,6 @@
 }
 
 
-def get_torch_dtype_from_str(dtype: str):
-    if dtype == "auto":
-        return dtype
-    if dtype == "fp32":
-        return torch.float32
-    if dtype == "fp16":
-        return torch.float16
-    if dtype == "bf16":
-        return torch.bfloat16
-    if dtype == "int8":
-        return torch.int8
-    if dtype is None:
-        return None
-    raise ValueError(f"Invalid data type: {dtype}")
-
-
 def enable_flash():
     if torch.cuda.is_available():
         major, _ = torch.cuda.get_device_capability()
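
The deleted `get_torch_dtype_from_str` helper has an obvious home in the new properties layer. For reference, an equivalent standalone mapping (assumption: the refactored module may express this as a validator on the properties class rather than a free function):

```python
# Equivalent dtype mapping, mirroring the removed helper exactly; whether the
# refactored properties module keeps this exact function is an assumption.
import torch

_STR_2_TORCH_DTYPE = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "int8": torch.int8,
}


def get_torch_dtype_from_str(dtype):
    # "auto" and None pass through unchanged, as in the removed helper.
    if dtype is None or dtype == "auto":
        return dtype
    if dtype in _STR_2_TORCH_DTYPE:
        return _STR_2_TORCH_DTYPE[dtype]
    raise ValueError(f"Invalid data type: {dtype}")
```
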
@@ -143,107 +130,42 @@ def __init__(self):
         self.hf_pipeline = None
         self.hf_pipeline_unwrapped = None
         self.initialized = False
-        self.enable_streaming = None
         self.model = None
-        self.device = None
         self.tokenizer = None
-        self.trust_remote_code = os.environ.get("HF_TRUST_REMOTE_CODE",
-                                                "FALSE").lower() == 'true'
-        self.rolling_batch_type = None
         self.rolling_batch = None
         self.model_config = None
         self.peft_config = None
         self.stopping_criteria_list = None
-        self.disable_flash_attn = None
         self.adapters = None
+        self.hf_configs = None
 
     def initialize(self, properties: dict):
-        # model_id can point to huggingface model_id or local directory.
-        # If option.model_id points to a s3 bucket, we download it and set model_id to the download directory.
-        # Otherwise we assume model artifacts are in the model_dir
-        model_id_or_path = properties.get("model_id") or properties.get(
-            "model_dir")
-        device_id = int(properties.get("device_id", "-1"))
-        self.device = f"cuda:{device_id}" if device_id >= 0 else None
-        task = properties.get("task")
-        tp_degree = int(properties.get("tensor_parallel_degree", "-1"))
-        self.enable_streaming = properties.get("enable_streaming", None)
-        if self.enable_streaming and self.enable_streaming.lower() == "false":
-            self.enable_streaming = None
-        if "trust_remote_code" in properties:
-            self.trust_remote_code = properties.get(
-                "trust_remote_code").lower() == "true"
-        # HF Acc handling
-        kwargs = {"trust_remote_code": self.trust_remote_code}
-        # https://huggingface.co/docs/accelerate/usage_guides/big_modeling#designing-a-device-map
-        if "device_map" in properties:
-            kwargs["device_map"] = properties.get("device_map")
-            self.device = None
-            logging.info(f"Using device map {kwargs['device_map']}")
-        elif tp_degree > 0 and torch.cuda.device_count() > 0:
-            kwargs["device_map"] = "auto"
-            self.device = None
-            world_size = torch.cuda.device_count()
-            assert world_size == tp_degree, f"TP degree ({tp_degree}) doesn't match available GPUs ({world_size})"
-            logging.info(f"Using {world_size} gpus")
-        if "load_in_8bit" in properties:
-            if "device_map" not in kwargs:
-                raise ValueError(
-                    "device_map should set when load_in_8bit is set")
-            kwargs["load_in_8bit"] = properties.get(
-                "load_in_8bit").lower() == 'true'
-        if "load_in_4bit" in properties:
-            if "device_map" not in kwargs:
-                raise ValueError(
-                    "device_map should set when load_in_4bit is set")
-            kwargs["load_in_4bit"] = properties.get(
-                "load_in_4bit").lower() == 'true'
-        if "low_cpu_mem_usage" in properties:
-            kwargs["low_cpu_mem_usage"] = properties.get(
-                "low_cpu_mem_usage").lower() == 'true'
-
-        if "data_type" in properties:
-            kwargs["torch_dtype"] = get_torch_dtype_from_str(
-                properties.get("data_type"))
-        if "dtype" in properties:
-            kwargs["torch_dtype"] = get_torch_dtype_from_str(
-                properties.get("dtype"))
-        if "revision" in properties:
-            kwargs["revision"] = properties.get('revision')
-        self.disable_flash_attn = properties.get("disable_flash_attn",
-                                                 "true").lower() == 'true'
-        self.rolling_batch_type = properties.get("rolling_batch", None)
-
-        self._read_model_config(model_id_or_path,
-                                properties.get('revision', None))
-
-        if self.rolling_batch_type:
-            if "output_formatter" in properties:
-                kwargs["output_formatter"] = properties.get("output_formatter")
-            if "waiting_steps" in properties:
-                kwargs["waiting_steps"] = int(properties.get("waiting_steps"))
-            self.rolling_batch_type = self.rolling_batch_type.lower()
-            is_mpi = properties.get("engine") != "Python"
-            if is_mpi:
-                self.device = int(os.getenv("LOCAL_RANK", 0))
+        self.hf_configs = HuggingFaceProperties(**properties)
+        self._read_model_config(self.hf_configs.model_id_or_path,
+                                self.hf_configs.revision)
+
+        if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
             _rolling_batch_cls = get_rolling_batch_class_from_str(
-                self.rolling_batch_type, is_mpi, self.model_config)
-            self.rolling_batch = _rolling_batch_cls(model_id_or_path,
-                                                    self.device, properties,
-                                                    **kwargs)
+                self.hf_configs.rolling_batch.value, self.hf_configs.is_mpi,
+                self.model_config)
+            self.rolling_batch = _rolling_batch_cls(
+                self.hf_configs.model_id_or_path, self.hf_configs.device,
+                properties, **self.hf_configs.kwargs)
             self.initialized = True
             return
-        elif self.enable_streaming:
-            self._init_model_and_tokenizer(model_id_or_path, **kwargs)
+        elif is_streaming_enabled(self.hf_configs.enable_streaming):
+            self._init_model_and_tokenizer(self.hf_configs.model_id_or_path,
                                           **self.hf_configs.kwargs)
             self.initialized = True
             return
 
-        if not task:
-            task = self.infer_task_from_model_architecture()
+        if not self.hf_configs.task:
+            self.hf_configs.task = self.infer_task_from_model_architecture()
 
-        self.hf_pipeline = self.get_pipeline(task=task,
-                                             model_id_or_path=model_id_or_path,
-                                             kwargs=kwargs)
+        self.hf_pipeline = self.get_pipeline(
+            task=self.hf_configs.task,
+            model_id_or_path=self.hf_configs.model_id_or_path,
+            kwargs=self.hf_configs.kwargs)
 
         if "stop_sequence" in properties:
             self.load_stopping_criteria_list(properties["stop_sequence"])
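
With option parsing delegated to `HuggingFaceProperties`, `initialize()` reduces to three steps: build the config object, read the model config, then branch on rolling batch, streaming, or a plain pipeline. A hedged usage sketch (assuming the handler class defined in this file is `HuggingFaceService`; the property keys shown mirror ones the removed code handled, and the accepted set is now defined by `hf_properties`):

```python
# Usage sketch under the new flow; the keys below are illustrative, not an
# exhaustive or guaranteed list -- HuggingFaceProperties defines what is
# accepted and how each value is validated.
from djl_python.huggingface import HuggingFaceService

service = HuggingFaceService()
service.initialize({
    "model_id": "gpt2",          # hub id or local/downloaded artifact dir
    "task": "text-generation",
    "dtype": "fp16",
    "trust_remote_code": "false",
})
assert service.initialized
```
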
@@ -342,7 +264,7 @@ def inference(self, inputs):
         if len(input_data) == 0:
             for i in range(len(batch)):
                 err = errors.get(i)
-                if self.rolling_batch_type:
+                if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
                     err = {"data": "", "last": True, "code": 424, "error": err}
                     outputs.add(Output.binary_encode(err),
                                 key="data",
@@ -351,7 +273,7 @@
                     outputs.add(err, key="data", batch_index=i)
             return outputs
 
-        if self.rolling_batch_type:
+        if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
             if inputs.get_property("reset_rollingbatch"):
                 self.rolling_batch.reset()
             result = self.rolling_batch.inference(input_data, parameters)
@@ -373,22 +295,22 @@
             if content_type:
                 outputs.add_property("content-type", content_type)
             return outputs
-        elif self.enable_streaming:
+        elif is_streaming_enabled(self.hf_configs.enable_streaming):
             if len(batch) > 1:
                 raise NotImplementedError(
                     "Dynamic batch not supported for generic streaming")
             outputs.add_property("content-type", "application/jsonlines")
-            if self.enable_streaming == "huggingface":
+            if self.hf_configs.enable_streaming.value == StreamingEnum.huggingface.value:
                 outputs.add_stream_content(
                     StreamingUtils.use_hf_default_streamer(
-                        self.model, self.tokenizer, input_data, self.device,
-                        **parameters[0]))
+                        self.model, self.tokenizer, input_data,
+                        self.hf_configs.device, **parameters[0]))
             else:
                 stream_generator = StreamingUtils.get_stream_generator(
                     "Accelerate")
                 outputs.add_stream_content(
                     stream_generator(self.model, self.tokenizer, input_data,
-                                     self.device, **parameters[0]))
+                                     self.hf_configs.device, **parameters[0]))
             return outputs
 
         if not all(p == parameters[0] for p in parameters):
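
The branch conditions above now go through small predicates instead of truthiness checks on raw strings, which is why the special-casing of `enable_streaming == "false"` disappears from this file. An assumed shape for those helpers (the real definitions live in `djl_python/properties_manager/properties.py` and may differ in detail):

```python
# Assumed shapes; the real StreamingEnum and predicates live in
# djl_python.properties_manager.properties and may differ.
from enum import Enum
from typing import Optional


class StreamingEnum(str, Enum):
    false = "false"
    true = "true"
    huggingface = "huggingface"


def is_streaming_enabled(enable_streaming: Optional[StreamingEnum]) -> bool:
    # Streaming is on for any value other than "false"/unset.
    return enable_streaming is not None and enable_streaming != StreamingEnum.false


def is_rolling_batch_enabled(rolling_batch) -> bool:
    # Any selected rolling-batch backend other than "disable" enables it
    # (assumption based on the call sites in this file).
    value = getattr(rolling_batch, "value", rolling_batch)
    return value is not None and value != "disable"
```
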
@@ -463,22 +385,22 @@ def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
                 hf_pipeline = pipeline(task=task,
                                        tokenizer=self.tokenizer,
                                        model=self.model,
-                                       device=self.device)
+                                       device=self.hf_configs.device)
             else:
                 self.tokenizer = AutoTokenizer.from_pretrained(
                     model_id_or_path, revision=kwargs.get('revision', None))
                 hf_pipeline = pipeline(task=task,
                                        tokenizer=self.tokenizer,
                                        model=model_id_or_path,
-                                       device=self.device,
+                                       device=self.hf_configs.device,
                                        **kwargs)
                 self.model = hf_pipeline.model
         else:
             self._init_model_and_tokenizer(model_id_or_path, **kwargs)
             hf_pipeline = pipeline(task=task,
                                    model=self.model,
                                    tokenizer=self.tokenizer,
-                                   device=self.device)
+                                   device=self.hf_configs.device)
 
         # wrap specific pipeline to support better ux
         if task == "conversational":
Expand Down Expand Up @@ -515,7 +437,7 @@ def _init_model_and_tokenizer(self, model_id_or_path: str, **kwargs):
else:
model_cls = AutoModelForCausalLM
if architectures[0] in FLASH_2_SUPPORTED_MODELS and enable_flash(
) and not self.disable_flash_attn:
) and not self.hf_configs.disable_flash_attn:
kwargs['use_flash_attention_2'] = True

if self.peft_config is not None:
@@ -533,8 +455,8 @@ def _init_model_and_tokenizer(self, model_id_or_path: str, **kwargs):
         else:
             self.model = model_cls.from_pretrained(model_id_or_path, **kwargs)
 
-        if self.device:
-            self.model.to(self.device)
+        if self.hf_configs.device:
+            self.model.to(self.hf_configs.device)
 
     @staticmethod
     def wrap_conversation_pipeline(hf_pipeline):
@@ -562,8 +484,8 @@ def wrapped_pipeline(inputs, *args, **kwargs):
             model = hf_pipeline.model
             tokenizer = hf_pipeline.tokenizer
             input_tokens = tokenizer(inputs, padding=True, return_tensors="pt")
-            if self.device:
-                input_tokens = input_tokens.to(self.device)
+            if self.hf_configs.device:
+                input_tokens = input_tokens.to(self.hf_configs.device)
             else:
                 input_tokens = input_tokens.to(model.device)
 
@@ -599,7 +521,7 @@ def _read_model_config(self, model_config_path: str, revision=None):
         try:
             self.model_config = AutoConfig.from_pretrained(
                 model_config_path,
-                trust_remote_code=self.trust_remote_code,
+                trust_remote_code=self.hf_configs.trust_remote_code,
                 revision=revision)
         except OSError:
             logging.warning(
@@ -608,7 +530,7 @@
                 self.peft_config = PeftConfig.from_pretrained(model_config_path)
                 self.model_config = AutoConfig.from_pretrained(
                     self.peft_config.base_model_name_or_path,
-                    trust_remote_code=self.trust_remote_code)
+                    trust_remote_code=self.hf_configs.trust_remote_code)
             except Exception as e:
                 logging.error(
                     f"{model_config_path} does not contain a config.json or adapter_config.json for lora models. "