Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add quantization parameter for lmi_dist rolling batch backend for HF #888

Merged
merged 7 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from djl_python.streaming_utils import StreamingUtils
from djl_python.rolling_batch import SchedulerRollingBatch

os.environ["BITSANDBYTES_NOWELCOME"] = "1"
maaquib marked this conversation as resolved.
Show resolved Hide resolved

ARCHITECTURES_2_TASK = {
"TapasForQuestionAnswering": "table-question-answering",
"ForQuestionAnswering": "question-answering",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@
StoppingCriteriaParameters,
)
import lmi_dist
from lmi_dist.utils.types import (
Batch,
Request,
Generation
)
from lmi_dist.utils.types import (Batch, Request, Generation)

import torch

Expand All @@ -36,6 +32,8 @@
"LlamaForCausalLM": FlashCausalLMBatch
}

QUANTIZATION_SUPPORT_ALGO = ["bitsandbytes"]


def get_batch_cls_from_architecture(architecture):
if architecture in ARCHITECTURE_2_BATCH_CLS:
Expand Down Expand Up @@ -63,15 +61,27 @@ def __init__(self, model_id_or_path, device, properties, **kwargs):
self.cache: Batch = None

def _init_model(self, kwargs, model_id_or_path):
self.config = AutoConfig.from_pretrained(model_id_or_path,
**kwargs)
self.batch_cls = get_batch_cls_from_architecture(self.config.architectures[0])
self.config = AutoConfig.from_pretrained(model_id_or_path, **kwargs)
self.batch_cls = get_batch_cls_from_architecture(
self.config.architectures[0])
sharded = int(self.properties.get("tensor_parallel_degree", "-1")) > 1
self.model = get_model(model_id_or_path,
revision=None,
sharded=sharded,
quantize=None,
trust_remote_code=kwargs.get("trust_remote_code"))
quantize = self.properties.get("quantize", None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the original properties, we have option.load_in_8bit, can we reuse this param?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load_in_8bit is a boolean. Assuming we add gptq support in next release we need a parameter which can take the quantization algo name instead of just a boolean. @lanking520 thoughts?

dtype = self.properties.get("dtype", None)
if quantize is not None and dtype is not None:
raise ValueError(
f"Can't set both dtype: {dtype} and quantize: {quantize}")
if quantize is not None and quantize not in QUANTIZATION_SUPPORT_ALGO:
maaquib marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
f"Invalid value for quantize: {quantize}. Valid values are: {QUANTIZATION_SUPPORT_ALGO}"
)
if quantize is None and dtype == "int8":
quantize = "bitsandbytes"
self.model = get_model(
model_id_or_path,
revision=None,
sharded=sharded,
quantize=quantize,
trust_remote_code=kwargs.get("trust_remote_code"))

def inference(self, input_data, parameters):
"""
Expand All @@ -91,15 +101,18 @@ def inference(self, input_data, parameters):
def _prefill_and_decode(self, new_batch):
# prefill step
if new_batch:
generations, prefill_next_batch = self.model.generate_token(new_batch)
generations, prefill_next_batch = self.model.generate_token(
new_batch)

if self.cache:
decode_generations, decode_next_batch = self.model.generate_token(self.cache)
decode_generations, decode_next_batch = self.model.generate_token(
self.cache)
generations.extend(decode_generations)

# concatenate with the existing batch of the model
if decode_next_batch:
self.cache = self.model.batch_type.concatenate([prefill_next_batch, decode_next_batch])
self.cache = self.model.batch_type.concatenate(
[prefill_next_batch, decode_next_batch])
else:
self.cache = prefill_next_batch
else:
Expand All @@ -108,7 +121,10 @@ def _prefill_and_decode(self, new_batch):
generations, next_batch = self.model.generate_token(self.cache)
self.cache = next_batch

generation_dict = {generation.request_id: generation for generation in generations}
generation_dict = {
generation.request_id: generation
for generation in generations
}

req_ids = []
for r in self.pending_requests:
Expand All @@ -127,22 +143,25 @@ def preprocess_requests(self, requests, **kwargs):
for r in requests:
param = r.parameters
parameters = NextTokenChooserParameters(
temperature=param.get("temperature", 0.5), # TODO: Find a better place to put default values
temperature=param.get(
"temperature",
0.5), # TODO: Find a better place to put default values
repetition_penalty=param.get("repetition_penalty", 1.0),
top_k=param.get("top_k", 4),
top_p=param.get("top_p", 1.0),
typical_p=param.get("typical_p", 1.0),
do_sample=param.get("do_sample", False),
)
stop_parameters = StoppingCriteriaParameters(stop_sequences=param.get("stop_sequences", []),
max_new_tokens=param.get("max_new_tokens", 30))
stop_parameters = StoppingCriteriaParameters(
stop_sequences=param.get("stop_sequences", []),
max_new_tokens=param.get("max_new_tokens", 30))

preprocessed_requests.append(lmi_dist.utils.types.Request(
id=r.id,
inputs=r.input_text,
parameters=parameters,
stopping_parameters=stop_parameters
))
preprocessed_requests.append(
lmi_dist.utils.types.Request(
id=r.id,
inputs=r.input_text,
parameters=parameters,
stopping_parameters=stop_parameters))

if preprocessed_requests:
batch = Batch(id=self.batch_id_counter,
Expand All @@ -151,10 +170,7 @@ def preprocess_requests(self, requests, **kwargs):
self.batch_id_counter += 1

return self.batch_cls.get_batch(
batch,
self.model.tokenizer,
kwargs.get("torch_dtype", torch.float16),
self.device
)
batch, self.model.tokenizer,
kwargs.get("torch_dtype", torch.float16), self.device)
else:
return None