Rolling batch for huggingface handler #857

Merged
merged 3 commits on Jun 24, 2023
Changes from 1 commit
separate model initializer
sindhuvahinis committed Jun 23, 2023
commit 5c68c36d13e436ab966e477c78fe347a3640adf6
engines/python/setup/djl_python/huggingface.py (8 changes: 3 additions & 5 deletions)
@@ -126,12 +126,10 @@ def initialize(self, properties: dict):
             self.initialized = True
             return
         elif self.enable_rolling_batch:
-            self._init_model_and_tokenizer(model_id_or_path, **kwargs)
             # TODO: Add logic to call appropriate scheduler backend for rolling batch
-            self.rolling_batch = SchedulerRollingBatch(self.model,
-                                                       self.tokenizer,
-                                                       self.model_config,
-                                                       self.device, properties)
+            self.rolling_batch = SchedulerRollingBatch(model_id_or_path,
+                                                       self.device, properties,
+                                                       **kwargs)
             self.initialized = True
             return

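For reference, a minimal sketch of the call pattern this commit introduces: the handler hands the model id, device, and loader kwargs to the rolling batch, which now loads the model itself. The model id, device string, module path, and loader kwarg below are illustrative assumptions, not values taken from this diff.

    from djl_python.rolling_batch.scheduler_rolling_batch import SchedulerRollingBatch

    # Illustrative values; 'decoding_strategy' is the property key this diff reads.
    properties = {"decoding_strategy": "greedy"}
    rolling_batch = SchedulerRollingBatch("gpt2",        # model_id_or_path
                                          "cpu",         # device
                                          properties,
                                          low_cpu_mem_usage=True)  # example loader kwarg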
engines/python/setup/djl_python/rolling_batch/rolling_batch.py
@@ -31,7 +31,7 @@ def __init__(self, input_text: str, parameters: dict):
         :param input_text: request's input text
         """
         self.input_text = input_text
-        self.paramaters = parameters
+        self.parameters = parameters
         self.next_token = None
         self.last_token = False
@@ -70,15 +70,13 @@ class RollingBatch(ABC):
 
     """
 
-    def __init__(self, model, device):
+    def __init__(self, device):
         """
         Initializes the rolling batch scheduler.
 
-        :param model: loaded model
-        :param device: model loaded device
+        :param device: device to load the model
         """
 
-        self.model = model
         self.device = device
         self.pending_requests = []
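With the model parameter gone, the base class only tracks the device and the pending-request queue; each backend now loads its own model. A minimal sketch of the resulting subclass contract, with a hypothetical backend name:

    class MyRollingBatch(RollingBatch):
        """Hypothetical backend showing the slimmed-down base contract."""

        def __init__(self, model_id_or_path, device, properties, **kwargs):
            super().__init__(device)  # base class no longer accepts a model
            # the subclass is now responsible for loading model and tokenizer

        def inference(self, input_data, parameters):
            raise NotImplementedError  # backend-specific decoding goes here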
engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py
@@ -14,6 +14,7 @@
 from djl_python.scheduler import HuggingfaceBlock, BloomBlock, SearchConfig, SeqBatchScheduler
 from collections import namedtuple, defaultdict
 from djl_python.rolling_batch.rolling_batch import RollingBatch, Request
+from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig
 
 import torch
 
@@ -23,36 +24,19 @@
 
 class SchedulerRollingBatch(RollingBatch):
 
-    def __init__(self, model, tokenizer, config, device, properties):
+    def __init__(self, model_id_or_path, device, properties, **kwargs):
         """
         Initializes the rolling batch scheduler.
 
-        :param model: loaded model
-        :param tokenizer: tokenizer of the model
-        :param config: configuration of the model
+        :param model_id_or_path: model id or path
         :param device: model loaded device
         :param properties: other properties of the model, such as decoder strategy
+        :param kwargs passed while loading the model
         """
 
-        super().__init__(model, device)
-        self.tokenizer = tokenizer
-        self.config = config
-
-        if not self.tokenizer.pad_token:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        lm_block_cls = MODEL_TYPE_2_BLOCK.get(self.config.model_type,
-                                              HuggingfaceBlock)
-        self.lm_block = lm_block_cls(self.model)
-        self.search_config = SearchConfig(
-            eos_token_id=self.tokenizer.eos_token,
-            pad_token_id=self.tokenizer.pad_token)
-        self.search_algorithm = properties.get('decoding_strategy',
-                                               DEFAULT_SEARCH_ALGORITHM)
-
-        self.scheduler = SeqBatchScheduler(self.lm_block,
-                                           self.search_algorithm,
-                                           self.search_config)
+        super().__init__(device)
+        self._init_model_and_tokenizer(kwargs, model_id_or_path)
+        self._init_scheduler(properties)
 
     def inference(self, input_data, parameters):
         """
@@ -78,7 +62,7 @@ def preprocess_requests(self, requests):
 
         req_id_counter = _calculate_req_id_counter(self.scheduler)
         for request in requests:
-            parameters = request.paramaters
+            parameters = request.parameters
             search_algorithm = parameters.get('decoding_strategy',
                                               self.search_algorithm)
             new_requests.input_texts[search_algorithm].append(
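Since preprocess_requests reads 'decoding_strategy' from each request's parameters before falling back to the server-wide default, callers can mix strategies within one batch. A short example; the strategy value is illustrative:

    # Request comes from djl_python.rolling_batch.rolling_batch (see above).
    per_request = Request("What is DJL?", {"decoding_strategy": "sampling"})
    uses_default = Request("What is DJL?", {})  # falls back to self.search_algorithm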
@@ -91,6 +75,38 @@
 
         return new_requests
 
+    def _init_model_and_tokenizer(self, kwargs, model_id_or_path):
+        self.config = AutoConfig.from_pretrained(model_id_or_path,
+                                                 kwargs=kwargs)
+        architectures = self.config.architectures
+        if architectures and architectures[0].endswith(
+                "ForConditionalGeneration"):
+            raise ValueError('Seq2Seq model is not supported by scheduler')
+        else:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_id_or_path, **kwargs)
+
+        if self.device:
+            self.model.to(self.device)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id_or_path,
+                                                       padding_side="left")
+        if not self.tokenizer.pad_token:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+    def _init_scheduler(self, properties):
+        lm_block_cls = MODEL_TYPE_2_BLOCK.get(self.config.model_type,
+                                              HuggingfaceBlock)
+        self.lm_block = lm_block_cls(self.model)
+        self.search_config = SearchConfig(
+            eos_token_id=self.tokenizer.eos_token,
+            pad_token_id=self.tokenizer.pad_token)
+        self.search_algorithm = properties.get('decoding_strategy',
+                                               DEFAULT_SEARCH_ALGORITHM)
+        self.scheduler = SeqBatchScheduler(self.lm_block,
+                                           self.search_algorithm,
+                                           self.search_config)
+
     def _prefill_and_decode(self, new_requests):
 
         for search_algorithm in new_requests.request_ids.keys():
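One behavioral consequence of moving loading into _init_model_and_tokenizer is that construction now fails fast on unsupported architectures. A hedged sketch; the model id is illustrative:

    # t5-small's first architecture name ends with "ForConditionalGeneration",
    # so the new helper rejects it before any model weights are loaded:
    try:
        SchedulerRollingBatch("t5-small", "cpu", {})
    except ValueError as err:
        print(err)  # Seq2Seq model is not supported by scheduler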