Skip to content

Commit

Permalink
Add support for stopwords in huggingface handler (#1118)
Browse files Browse the repository at this point in the history
  • Loading branch information
ydm-amazon authored Oct 4, 2023
1 parent fbd47b1 commit 7c9ea81
Showing 1 changed file with 54 additions and 1 deletion.
55 changes: 54 additions & 1 deletion engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
import json
import logging
import os
import re

import torch
from transformers import (pipeline, Pipeline, Conversation,
AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoTokenizer, AutoConfig,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelForQuestionAnswering)
AutoModelForQuestionAnswering, StoppingCriteria,
StoppingCriteriaList)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from peft import PeftConfig, PeftModel, PeftModelForCausalLM

Expand Down Expand Up @@ -98,6 +100,23 @@ def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool,
return VLLMRollingBatch
raise ValueError(f"Invalid rolling batch type: {rolling_batch_type}")

class StopWord(StoppingCriteria):
    """Stopping criterion that fires once a stop sequence shows up in the output.

    On every generation step the tail of the first sequence in the batch is
    decoded with the supplied tokenizer and checked for the stop sequence.

    NOTE(review): only ``input_ids[0]`` is inspected, so with batched
    generation the decision is driven by the first sequence alone.
    """

    def __init__(self, tokenizer, stop_seq):
        """
        Input: tokenizer - object exposing ``decode(token_ids) -> str``
               stop_seq (string) - literal text that should end generation
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.stop_seq = stop_seq

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor, **kwargs):
        """
        Returns True when the decoded tail of the first sequence contains the
        stop sequence, False otherwise.

        ``**kwargs`` is accepted (and ignored) so the criterion stays callable
        when the caller forwards extra keyword arguments.
        """
        # An empty stop sequence would match any text and halt generation on
        # the very first step, so it is treated as "never stop".
        if not self.stop_seq:
            return False

        # Decode only a window of the most recent tokens. The window holds as
        # many tokens as the stop sequence has characters.
        window = input_ids[0][-len(self.stop_seq):]
        decoded_tail = self.tokenizer.decode(window)

        # Literal substring test: the stop sequence is plain text, not a regex
        # pattern, so characters such as '.', '?' or '(' must match themselves.
        return self.stop_seq in decoded_tail

class HuggingFaceService(object):

Expand All @@ -115,6 +134,7 @@ def __init__(self):
self.rolling_batch = None
self.model_config = None
self.peft_config = None
self.stopping_criteria_list = None

def initialize(self, properties: dict):
# model_id can point to huggingface model_id or local directory.
Expand Down Expand Up @@ -202,8 +222,40 @@ def initialize(self, properties: dict):
model_id_or_path=model_id_or_path,
kwargs=kwargs)

if("stop_sequence" in properties):
self.load_stopping_criteria_list(properties["stop_sequence"])
self.initialized = True

def parse_stop_sequence_input(self, stop_sequence):
    """
    Gets a list of stop sequences by parsing the string given in
    serving.properties.

    The value must be a bracketed list, e.g. ["</s>", "\\n\\n"]. A valid JSON
    array of strings is decoded as JSON, so a stop sequence may contain
    commas. Anything else (e.g. single-quoted items) falls back to a simple
    comma split in which surrounding whitespace and one pair of matching
    quotes are removed from each item.

    Input: stop_sequence (string)
    Output: list of strings (empty for "[]")
    Raises: ValueError if the value is not wrapped in square brackets
    """
    text = stop_sequence.strip()
    # Explicit check instead of `assert`: asserts are stripped under -O, and
    # indexing an empty string would raise IndexError rather than a clear error.
    if len(text) < 2 or text[0] != '[' or text[-1] != ']':
        raise ValueError("option.stop_sequence not properly formatted")

    try:
        # strict=False tolerates raw control characters (e.g. a real newline)
        # inside the quoted strings.
        decoded = json.loads(text, strict=False)
    except ValueError:
        decoded = None
    if isinstance(decoded, list) and all(
            isinstance(item, str) for item in decoded):
        return decoded

    # Fallback for non-JSON input such as ['a', 'b'].
    inner = text[1:-1].strip()
    if not inner:
        return []

    stop_seq_list = []
    for element in inner.split(","):
        item = element.strip()
        # Drop one pair of matching quotes; leave unquoted items untouched.
        if len(item) >= 2 and item[0] == item[-1] and item[0] in "\"'":
            item = item[1:-1]
        stop_seq_list.append(item)
    return stop_seq_list

def load_stopping_criteria_list(self, stop_sequence):
    """
    Builds a StoppingCriteriaList from the configured stop sequences and
    stores it in self.stopping_criteria_list.

    Does nothing when self.tokenizer has not been set, since each StopWord
    criterion is constructed with the tokenizer.

    Input: (str) stop_sequence - raw option value, parsed by
           parse_stop_sequence_input into individual stop sequences
    Output: none (loads into member variable)
    """
    tokenizer = self.tokenizer
    if tokenizer is None:
        return

    # One StopWord criterion per parsed stop sequence.
    criteria = [
        StopWord(tokenizer, sequence)
        for sequence in self.parse_stop_sequence_input(stop_sequence)
    ]
    self.stopping_criteria_list = StoppingCriteriaList(criteria)

def parse_input(self, inputs):
input_data = []
input_size = []
Expand Down Expand Up @@ -493,6 +545,7 @@ def wrapped_pipeline(inputs, *args, **kwargs):
*args,
input_ids=input_tokens.input_ids,
attention_mask=input_tokens.attention_mask,
stopping_criteria=self.stopping_criteria_list,
**kwargs)
generated_text = tokenizer.batch_decode(output_tokens,
skip_special_tokens=True)
Expand Down

0 comments on commit 7c9ea81

Please sign in to comment.