From b0136593dff3dd851e14e2d030d7cfbcd3e527d9 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Tue, 18 Jul 2023 22:19:44 +0530
Subject: [PATCH] Add support for different compilation paths for DocuChat
 (#1665)

---
 .gitignore                                       |  4 +
 apps/language_models/langchain/README.md         |  6 +-
 apps/language_models/langchain/gen.py            |  2 +-
 apps/language_models/langchain/h2oai_pipeline.py | 74 +++++++++++++++----
 apps/stable_diffusion/src/utils/stable_args.py   | 10 +++
 apps/stable_diffusion/web/ui/h2ogpt.py           |  2 +
 6 files changed, 81 insertions(+), 17 deletions(-)

diff --git a/.gitignore b/.gitignore
index eeb217e2b6..f1bf381809 100644
--- a/.gitignore
+++ b/.gitignore
@@ -189,3 +189,7 @@ apps/stable_diffusion/web/models/
 
 # Stencil annotators.
 stencil_annotator/
+
+# For DocuChat
+apps/language_models/langchain/user_path/
+db_dir_UserData
diff --git a/apps/language_models/langchain/README.md b/apps/language_models/langchain/README.md
index 59c3f9b2d3..af89b8f4ec 100644
--- a/apps/language_models/langchain/README.md
+++ b/apps/language_models/langchain/README.md
@@ -6,10 +6,12 @@
 ```shell
 pip install -r apps/language_models/langchain/langchain_requirements.txt
 ```
-2.) Create a folder named `user_path` and all your docs into that folder.
+
+2.) Create a folder named `user_path` in the `apps/language_models/langchain/` directory and put all your docs into that folder.
+
 Now, you are ready to use the model.
 3.) To run the model, run the following command:
 ```shell
-python apps/language_models/langchain/gen.py --user_path= --cli=True
+python apps/language_models/langchain/gen.py --cli=True
 ```
diff --git a/apps/language_models/langchain/gen.py b/apps/language_models/langchain/gen.py
index e7689b122f..327717748a 100644
--- a/apps/language_models/langchain/gen.py
+++ b/apps/language_models/langchain/gen.py
@@ -177,7 +177,7 @@ def main(
         LangChainAction.SUMMARIZE_MAP.value,
     ],
     document_choice: list = [DocumentChoices.All_Relevant.name],
-    user_path: str = None,
+    user_path: str = "apps/language_models/langchain/user_path/",
     detect_user_path_changes_every_query: bool = False,
     load_db_if_exists: bool = True,
     keep_sources_in_context: bool = False,
diff --git a/apps/language_models/langchain/h2oai_pipeline.py b/apps/language_models/langchain/h2oai_pipeline.py
index fa580b0ff7..9088d30c2f 100644
--- a/apps/language_models/langchain/h2oai_pipeline.py
+++ b/apps/language_models/langchain/h2oai_pipeline.py
@@ -1,4 +1,5 @@
 import os
+from apps.stable_diffusion.src.utils.utils import _compile_module
 from transformers import TextGenerationPipeline
 from transformers.pipelines.text_generation import ReturnType
@@ -19,34 +20,79 @@
 from pathlib import Path
 from shark.shark_inference import SharkInference
 from shark.shark_downloader import download_public_file
+from apps.stable_diffusion.src import args
 
 global_device = "cuda"
 global_precision = "fp16"
 
+if not args.run_docuchat_web:
+    args.device = global_device
+    args.precision = global_precision
+
 
 class H2OGPTSHARKModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         model_name = "h2ogpt_falcon_7b"
         path_str = (
-            model_name + "_" + global_precision + "_" + global_device + ".vmfb"
+            model_name + "_" + args.precision + "_" + args.device + ".vmfb"
         )
         vmfb_path = Path(path_str)
+        path_str = model_name + "_" + args.precision + ".mlir"
+        mlir_path = Path(path_str)
+        shark_module = None
 
         if not vmfb_path.exists():
-            # Downloading VMFB from shark_tank
-            print("Downloading vmfb from shark tank.")
-            download_public_file(
-                "gs://shark_tank/langchain/" + path_str,
-                vmfb_path.absolute(),
-                single_file=True,
-            )
-        print("Compiled vmfb found. Loading it from: ", vmfb_path)
-        shark_module = SharkInference(
-            None, device=global_device, mlir_dialect="linalg"
-        )
-        shark_module.load_module(vmfb_path)
-        print("Compiled vmfb loaded successfully.")
+            if args.device == "cuda" and args.precision in ["fp16", "fp32"]:
+                # Downloading VMFB from shark_tank
+                print("Downloading vmfb from shark tank.")
+                download_public_file(
+                    "gs://shark_tank/langchain/" + path_str,
+                    vmfb_path.absolute(),
+                    single_file=True,
+                )
+            else:
+                if mlir_path.exists():
+                    with open(mlir_path, "rb") as f:
+                        bytecode = f.read()
+                else:
+                    # Downloading MLIR from shark_tank
+                    download_public_file(
+                        "gs://shark_tank/langchain/"
+                        + model_name
+                        + "_"
+                        + args.precision
+                        + ".mlir",
+                        mlir_path.absolute(),
+                        single_file=True,
+                    )
+                    if mlir_path.exists():
+                        with open(mlir_path, "rb") as f:
+                            bytecode = f.read()
+                    else:
+                        raise ValueError(
+                            f"MLIR not found at {mlir_path.absolute()}"
+                            " after downloading! Please check path and try again"
+                        )
+                shark_module = SharkInference(
+                    mlir_module=bytecode,
+                    device=args.device,
+                    mlir_dialect="linalg",
+                )
+                print("[DEBUG] generating vmfb.")
+                shark_module = _compile_module(shark_module, vmfb_path, [])
+                print("Saved newly generated vmfb.")
+
+        if shark_module is None:
+            if vmfb_path.exists():
+                print("Compiled vmfb found. Loading it from: ", vmfb_path)
+                shark_module = SharkInference(
+                    None, device=args.device, mlir_dialect="linalg"
+                )
+                shark_module.load_module(vmfb_path)
+                print("Compiled vmfb loaded successfully.")
+            else:
+                raise ValueError("Unable to download/generate a vmfb.")
 
         self.model = shark_module
diff --git a/apps/stable_diffusion/src/utils/stable_args.py b/apps/stable_diffusion/src/utils/stable_args.py
index 336e94bea4..483cfa42a3 100644
--- a/apps/stable_diffusion/src/utils/stable_args.py
+++ b/apps/stable_diffusion/src/utils/stable_args.py
@@ -648,6 +648,16 @@ def is_valid_file(arg):
     help="Op to be optimized, options are matmul, bmm, conv and all.",
 )
 
+##############################################################################
+# DocuChat Flags
+##############################################################################
+
+p.add_argument(
+    "--run_docuchat_web",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="Specifies whether DocuChat's web version is running.",
+)
 
 args, unknown = p.parse_known_args()
 
 if args.import_debug:
diff --git a/apps/stable_diffusion/web/ui/h2ogpt.py b/apps/stable_diffusion/web/ui/h2ogpt.py
index 17489ef0cb..9010a15c32 100644
--- a/apps/stable_diffusion/web/ui/h2ogpt.py
+++ b/apps/stable_diffusion/web/ui/h2ogpt.py
@@ -12,6 +12,7 @@
     LangChainAction,
 )
 import apps.language_models.langchain.gen as gen
+from apps.stable_diffusion.src import args
 
 
 def user(message, history):
@@ -80,6 +81,7 @@ def create_prompt(model_name, history):
 
 
 def chat(curr_system_message, history, model, device, precision):
+    args.run_docuchat_web = True
     global sharded_model
     global past_key_values
     global h2ogpt_model
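--
Reviewer note: the vmfb-resolution flow this patch adds to
H2OGPTSHARKModel.__init__ condenses to roughly the sketch below. It is a
minimal illustration under the patch's own assumptions, not repo code:
get_h2ogpt_module is a hypothetical name, and SharkInference,
download_public_file, and _compile_module are assumed to behave the way the
diff above uses them.

from pathlib import Path

from apps.stable_diffusion.src.utils.utils import _compile_module
from shark.shark_downloader import download_public_file
from shark.shark_inference import SharkInference

TANK = "gs://shark_tank/langchain/"


def get_h2ogpt_module(model_name: str, device: str, precision: str):
    """Hypothetical helper mirroring the patched __init__ logic."""
    vmfb_path = Path(f"{model_name}_{precision}_{device}.vmfb")
    # The MLIR filename has no device component: one MLIR per precision.
    mlir_path = Path(f"{model_name}_{precision}.mlir")

    if not vmfb_path.exists():
        if device == "cuda" and precision in ["fp16", "fp32"]:
            # Prebuilt vmfbs exist in the tank only for these configurations.
            download_public_file(
                TANK + vmfb_path.name, vmfb_path.absolute(), single_file=True
            )
        else:
            # Any other device/precision pair: fetch the MLIR (unless already
            # cached) and compile it locally for this device.
            if not mlir_path.exists():
                download_public_file(
                    TANK + mlir_path.name, mlir_path.absolute(), single_file=True
                )
            if not mlir_path.exists():
                raise ValueError(f"MLIR not found at {mlir_path.absolute()}")
            module = SharkInference(
                mlir_module=mlir_path.read_bytes(),
                device=device,
                mlir_dialect="linalg",
            )
            # Assumed, per the diff, to write the vmfb to vmfb_path and
            # return a ready-to-run module.
            return _compile_module(module, vmfb_path, [])

    if not vmfb_path.exists():
        raise ValueError("Unable to download/generate a vmfb.")
    module = SharkInference(None, device=device, mlir_dialect="linalg")
    module.load_module(vmfb_path)
    return module

The design point worth noting: prebuilt vmfbs are published only for cuda at
fp16/fp32, so every other device/precision pair pays a one-time local compile,
after which the vmfb cached at vmfb_path lets subsequent runs load directly.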