Skip to content

Commit

Permalink
Fix for Langchain (nod-ai#1694)
Browse files Browse the repository at this point in the history
For CPU, remove max time stopping criteria
Fix web UI issue
  • Loading branch information
vivekkhandelwal1 authored Jul 26, 2023
1 parent 9d399eb commit 776a9c2
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
count = 1
show-source = 1
select = E9,F63,F7,F82
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py
27 changes: 23 additions & 4 deletions apps/language_models/langchain/h2oai_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,15 @@
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl

def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:

def brevitas〇matmul_rhs_group_quant〡shape(
lhs: List[int],
rhs: List[int],
rhs_scale: List[int],
rhs_zero_point: List[int],
rhs_bit_width: int,
rhs_group_size: int,
) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
Expand All @@ -39,20 +47,30 @@ def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rh
raise ValueError("Input shapes not supported.")


def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
def brevitas〇matmul_rhs_group_quant〡dtype(
    lhs_rank_dtype: Tuple[int, int],
    rhs_rank_dtype: Tuple[int, int],
    rhs_scale_rank_dtype: Tuple[int, int],
    rhs_zero_point_rank_dtype: Tuple[int, int],
    rhs_bit_width: int,
    rhs_group_size: int,
) -> int:
    """Return the result dtype of the group-quantized matmul.

    The result inherits the dtype of the floating-point lhs operand; the
    rank half of the (rank, dtype) pair and all rhs/quantization inputs
    are ignored for dtype purposes.
    """
    _, result_dtype = lhs_rank_dtype
    return result_dtype


def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
def brevitas〇matmul_rhs_group_quant〡has_value_semantics(
    lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
) -> None:
    """No-op marker: registering this function declares that the
    group-quantized matmul op has value semantics. The arguments are
    never inspected.
    """
    return None


brevitas_matmul_rhs_group_quant_library = [
brevitas〇matmul_rhs_group_quant〡shape,
brevitas〇matmul_rhs_group_quant〡dtype,
brevitas〇matmul_rhs_group_quant〡has_value_semantics]
brevitas〇matmul_rhs_group_quant〡has_value_semantics,
]

global_device = "cuda"
global_precision = "fp16"
Expand Down Expand Up @@ -541,6 +559,7 @@ def generate_new_token(self):
return next_token

def generate_token(self, **generate_kwargs):
del generate_kwargs["max_time"]
self.truncated_input_ids = []

generation_config_ = GenerationConfig.from_model_config(
Expand Down
6 changes: 2 additions & 4 deletions apps/language_models/langchain/langchain_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# for generate (gradio server) and finetune
datasets==2.13.0
sentencepiece==0.1.99
# gradio==3.37.0
huggingface_hub==0.16.4
appdirs==1.4.4
fire==0.5.0
docutils==0.20.1
# torch==2.0.1; sys_platform != "darwin" and platform_machine != "arm64"
evaluate==0.4.0
rouge_score==0.1.2
sacrebleu==2.3.1
Expand All @@ -21,7 +19,7 @@ bitsandbytes==0.39.0
accelerate==0.20.3
peft==0.4.0
# 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026)
# transformers==4.30.2
transformers==4.30.2
tokenizers==0.13.3
APScheduler==3.10.1

Expand Down Expand Up @@ -67,7 +65,7 @@ tiktoken==0.4.0
openai==0.27.8

# optional for chat with PDF
langchain==0.0.235
langchain==0.0.202
pypdf==3.12.2
# avoid textract, requires old six
#textract==1.6.5
Expand Down
2 changes: 1 addition & 1 deletion apps/stable_diffusion/web/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
upscaler_status,
]
)
with gr.TabItem(label="DocuChat(Experimental)", id=9):
with gr.TabItem(label="DocuChat(Experimental)", id=10):
h2ogpt_web.render()

# send to buttons
Expand Down

0 comments on commit 776a9c2

Please sign in to comment.