Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for stopwords in huggingface handler #1118

Merged
merged 1 commit into from
Oct 4, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
import json
import logging
import os
import re

import torch
from transformers import (pipeline, Pipeline, Conversation,
AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoTokenizer, AutoConfig,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelForQuestionAnswering)
AutoModelForQuestionAnswering, StoppingCriteria,
StoppingCriteriaList)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from peft import PeftConfig, PeftModel, PeftModelForCausalLM

Expand Down Expand Up @@ -98,6 +100,23 @@ def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool,
return VLLMRollingBatch
raise ValueError(f"Invalid rolling batch type: {rolling_batch_type}")

class StopWord(StoppingCriteria):
    """Stopping criterion that ends generation once a stop sequence appears.

    On every call the most recent tokens of the first sequence in the batch
    are decoded with ``tokenizer`` and searched for ``stop_seq``.

    ``stop_seq`` is searched with regular-expression semantics (``re.search``).
    A stop sequence that is not a valid regular expression is matched as
    literal text instead of raising during generation. An empty stop sequence
    never triggers a stop.

    NOTE(review): only ``input_ids[0]`` is inspected, so with a batch larger
    than one the decision is driven by the first sequence alone -- confirm
    this is intended by the callers.
    """

    def __init__(self, tokenizer, stop_seq):
        StoppingCriteria.__init__(self)
        self.tokenizer = tokenizer
        self.stop_seq = stop_seq
        # Compile once here rather than re-resolving the pattern on every
        # generation step.
        if not stop_seq:
            # An empty pattern matches any text and would stop generation
            # immediately; treat it as "no stop sequence".
            self._pattern = None
        else:
            try:
                self._pattern = re.compile(stop_seq)
            except re.error:
                # Not valid regex syntax (e.g. "[" or "("): match literally.
                self._pattern = re.compile(re.escape(stop_seq))

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor,
                 **kwargs) -> bool:
        """Return True when the decoded tail of the first sequence matches.

        Extra keyword arguments are accepted and ignored so the criterion can
        be invoked with whatever additional arguments the caller forwards.
        """
        if self._pattern is None:
            return False

        # The window is len(stop_seq) *tokens* wide; this assumes each token
        # decodes to at least one character -- TODO confirm for the
        # tokenizers in use.
        window = input_ids[0][-len(self.stop_seq):]
        decoded_tail = self.tokenizer.decode(window)
        return self._pattern.search(decoded_tail) is not None

class HuggingFaceService(object):

Expand All @@ -115,6 +134,7 @@ def __init__(self):
self.rolling_batch = None
self.model_config = None
self.peft_config = None
self.stopping_criteria_list = None

def initialize(self, properties: dict):
# model_id can point to huggingface model_id or local directory.
Expand Down Expand Up @@ -202,8 +222,40 @@ def initialize(self, properties: dict):
model_id_or_path=model_id_or_path,
kwargs=kwargs)

if("stop_sequence" in properties):
self.load_stopping_criteria_list(properties["stop_sequence"])
self.initialized = True

def parse_stop_sequence_input(self, stop_sequence):
    """
    Gets a list of stop sequences by parsing the string given in
    serving.properties, e.g. '["</s>", "User:"]'.

    Elements are separated by commas. An element may be wrapped in single
    or double quotes; commas and whitespace inside the quotes are kept as
    part of the stop sequence. Whitespace around elements is ignored and
    empty elements are dropped, so "[]" yields an empty list. Escape
    sequences are not interpreted; the text between the quotes is returned
    as is.

    Input: stop_sequence (string)
    Output: list of strings
    Raises: ValueError if the value is not enclosed in square brackets
    """
    text = stop_sequence.strip()
    # Explicit raise instead of assert: asserts are removed under "python -O".
    if len(text) < 2 or text[0] != '[' or text[-1] != ']':
        raise ValueError("option.stop_sequence not properly formatted")

    # Split on commas that are outside of quotes.
    pieces = []
    current = []
    open_quote = None
    for char in text[1:-1]:
        if open_quote is not None:
            current.append(char)
            if char == open_quote:
                open_quote = None
        elif char in ('"', "'"):
            open_quote = char
            current.append(char)
        elif char == ',':
            pieces.append("".join(current))
            current = []
        else:
            current.append(char)
    pieces.append("".join(current))

    stop_seq_list = []
    for piece in pieces:
        element = piece.strip()
        # Remove one pair of matching surrounding quotes, if present.
        if (len(element) >= 2 and element[0] == element[-1]
                and element[0] in ('"', "'")):
            element = element[1:-1]
        # An empty stop sequence would match any text; skip it.
        if element:
            stop_seq_list.append(element)
    return stop_seq_list

def load_stopping_criteria_list(self, stop_sequence):
    """
    Builds the StoppingCriteriaList from the configured stop sequences and
    stores it in self.stopping_criteria_list.

    The raw string is parsed with parse_stop_sequence_input and one StopWord
    criterion, bound to self.tokenizer, is created per parsed entry. When no
    tokenizer has been loaded the method returns without changing anything.

    Input: (str) stop_sequence - raw stop sequence string
    Output: none (loads into member variable)
    """
    tokenizer = self.tokenizer
    if tokenizer is None:
        # Nothing to decode with; leave the member variable untouched.
        return

    criteria = [
        StopWord(tokenizer, sequence)
        for sequence in self.parse_stop_sequence_input(stop_sequence)
    ]
    self.stopping_criteria_list = StoppingCriteriaList(criteria)

def parse_input(self, inputs):
input_data = []
input_size = []
Expand Down Expand Up @@ -493,6 +545,7 @@ def wrapped_pipeline(inputs, *args, **kwargs):
*args,
input_ids=input_tokens.input_ids,
attention_mask=input_tokens.attention_mask,
stopping_criteria=self.stopping_criteria_list,
**kwargs)
generated_text = tokenizer.batch_decode(output_tokens,
skip_special_tokens=True)
Expand Down
Loading