
Add support for different compilation paths for DocuChat (nod-ai#1665)
vivekkhandelwal1 authored Jul 18, 2023
1 parent 11f62d7 commit b013659
Showing 6 changed files with 81 additions and 17 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -189,3 +189,7 @@ apps/stable_diffusion/web/models/
 
 # Stencil annotators.
 stencil_annotator/
+
+# For DocuChat
+apps/language_models/langchain/user_path/
+db_dir_UserData
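Both new entries are runtime artifacts rather than source: `user_path/` holds the documents a user drops in (see the README change below), and `db_dir_UserData` appears to be the vector database DocuChat builds from them.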
6 changes: 4 additions & 2 deletions apps/language_models/langchain/README.md
@@ -6,10 +6,12 @@
 ```shell
 pip install -r apps/language_models/langchain/langchain_requirements.txt
 ```
-2.) Create a folder named `user_path` and all your docs into that folder.
+
+2.) Create a folder named `user_path` in `apps/language_models/langchain/` directory.
+
 Now, you are ready to use the model.
 
 3.) To run the model, run the following command:
 ```shell
-python apps/language_models/langchain/gen.py --user_path=<path_to_user_path_directory> --cli=True
+python apps/language_models/langchain/gen.py --cli=True
 ```
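Taken together, a minimal end-to-end run now looks like the sketch below; the sample document name is illustrative:

```shell
# Put source documents into the default user_path folder...
mkdir -p apps/language_models/langchain/user_path
cp ~/my_docs/report.pdf apps/language_models/langchain/user_path/  # hypothetical file
# ...then start DocuChat in CLI mode.
python apps/language_models/langchain/gen.py --cli=True
```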
2 changes: 1 addition & 1 deletion apps/language_models/langchain/gen.py
@@ -177,7 +177,7 @@ def main(
         LangChainAction.SUMMARIZE_MAP.value,
     ],
     document_choice: list = [DocumentChoices.All_Relevant.name],
-    user_path: str = None,
+    user_path: str = "apps/language_models/langchain/user_path/",
    detect_user_path_changes_every_query: bool = False,
     load_db_if_exists: bool = True,
     keep_sources_in_context: bool = False,
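With this default in place, `gen.py` resolves documents from the in-repo `user_path/` folder without any extra flag. Since `user_path` is still an ordinary parameter of `main`, pointing it at another location the old way should still work; a sketch (the path is a placeholder):

```shell
python apps/language_models/langchain/gen.py --user_path=/absolute/path/to/docs --cli=True
```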
74 changes: 60 additions & 14 deletions apps/language_models/langchain/h2oai_pipeline.py
@@ -1,4 +1,5 @@
 import os
+from apps.stable_diffusion.src.utils.utils import _compile_module
 
 from transformers import TextGenerationPipeline
 from transformers.pipelines.text_generation import ReturnType
@@ -19,34 +20,79 @@
 from pathlib import Path
 from shark.shark_inference import SharkInference
 from shark.shark_downloader import download_public_file
+from apps.stable_diffusion.src import args
 
 global_device = "cuda"
 global_precision = "fp16"
 
+if not args.run_docuchat_web:
+    args.device = global_device
+    args.precision = global_precision
+
 
 class H2OGPTSHARKModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         model_name = "h2ogpt_falcon_7b"
         path_str = (
-            model_name + "_" + global_precision + "_" + global_device + ".vmfb"
+            model_name + "_" + args.precision + "_" + args.device + ".vmfb"
         )
         vmfb_path = Path(path_str)
+        path_str = model_name + "_" + args.precision + ".mlir"
+        mlir_path = Path(path_str)
+        shark_module = None
 
         if not vmfb_path.exists():
-            # Downloading VMFB from shark_tank
-            print("Downloading vmfb from shark tank.")
-            download_public_file(
-                "gs://shark_tank/langchain/" + path_str,
-                vmfb_path.absolute(),
-                single_file=True,
-            )
-        print("Compiled vmfb found. Loading it from: ", vmfb_path)
-        shark_module = SharkInference(
-            None, device=global_device, mlir_dialect="linalg"
-        )
-        shark_module.load_module(vmfb_path)
-        print("Compiled vmfb loaded successfully.")
+            if args.device == "cuda" and args.precision in ["fp16", "fp32"]:
+                # Downloading VMFB from shark_tank
+                print("Downloading vmfb from shark tank.")
+                download_public_file(
+                    "gs://shark_tank/langchain/" + path_str,
+                    vmfb_path.absolute(),
+                    single_file=True,
+                )
+            else:
+                if mlir_path.exists():
+                    with open(mlir_path, "rb") as f:
+                        bytecode = f.read()
+                else:
+                    # Downloading MLIR from shark_tank
+                    download_public_file(
+                        "gs://shark_tank/langchain/"
+                        + model_name
+                        + "_"
+                        + args.precision
+                        + ".mlir",
+                        mlir_path.absolute(),
+                        single_file=True,
+                    )
+                    if mlir_path.exists():
+                        with open(mlir_path, "rb") as f:
+                            bytecode = f.read()
+                    else:
+                        raise ValueError(
+                            f"MLIR not found at {mlir_path.absolute()}"
+                            " after downloading! Please check path and try again"
+                        )
+                shark_module = SharkInference(
+                    mlir_module=bytecode,
+                    device=args.device,
+                    mlir_dialect="linalg",
+                )
+                print(f"[DEBUG] generating vmfb.")
+                shark_module = _compile_module(shark_module, vmfb_path, [])
+                print("Saved newly generated vmfb.")
+
+        if shark_module is None:
+            if vmfb_path.exists():
+                print("Compiled vmfb found. Loading it from: ", vmfb_path)
+                shark_module = SharkInference(
+                    None, device=global_device, mlir_dialect="linalg"
+                )
+                shark_module.load_module(vmfb_path)
+                print("Compiled vmfb loaded successfully.")
+            else:
+                raise ValueError("Unable to download/generate a vmfb.")
 
         self.model = shark_module
 
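The branching above amounts to a three-way decision on how to obtain a compiled module. A condensed sketch of that selection logic (a paraphrase for readability, not the committed code):

```python
from pathlib import Path


def pick_compilation_path(vmfb_path: Path, device: str, precision: str) -> str:
    """Mirrors the branch structure in H2OGPTSHARKModel.__init__."""
    if vmfb_path.exists():
        # Reuse the cached artifact: SharkInference(None, ...).load_module(vmfb_path)
        return "load cached vmfb"
    if device == "cuda" and precision in ("fp16", "fp32"):
        # Prebuilt binaries are published on shark_tank only for this combination:
        # download_public_file(...) fetches the .vmfb, which is then loaded as above.
        return "download prebuilt vmfb"
    # Any other device/precision pair: fetch or reuse the .mlir bytecode, wrap it
    # in SharkInference, and let _compile_module(...) produce and save a fresh vmfb.
    return "compile vmfb from mlir"
```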
10 changes: 10 additions & 0 deletions apps/stable_diffusion/src/utils/stable_args.py
@@ -648,6 +648,16 @@ def is_valid_file(arg):
     help="Op to be optimized, options are matmul, bmm, conv and all.",
 )
 
+##############################################################################
+# DocuChat Flags
+##############################################################################
+
+p.add_argument(
+    "--run_docuchat_web",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="Specifies whether the docuchat's web version is running or not.",
+)
 
 args, unknown = p.parse_known_args()
 if args.import_debug:
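Because the flag is registered with `argparse.BooleanOptionalAction` (available since Python 3.9), the parser also accepts an auto-generated negated form. A self-contained illustration of that behavior:

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument(
    "--run_docuchat_web",
    default=False,
    action=argparse.BooleanOptionalAction,
)

print(p.parse_args([]))                         # Namespace(run_docuchat_web=False)
print(p.parse_args(["--run_docuchat_web"]))     # Namespace(run_docuchat_web=True)
print(p.parse_args(["--no-run_docuchat_web"]))  # Namespace(run_docuchat_web=False)
```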
2 changes: 2 additions & 0 deletions apps/stable_diffusion/web/ui/h2ogpt.py
@@ -12,6 +12,7 @@
     LangChainAction,
 )
 import apps.language_models.langchain.gen as gen
+from apps.stable_diffusion.src import args
 
 
 def user(message, history):
@@ -80,6 +81,7 @@ def create_prompt(model_name, history):
 
 
 def chat(curr_system_message, history, model, device, precision):
+    args.run_docuchat_web = True
     global sharded_model
     global past_key_values
     global h2ogpt_model
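Setting `args.run_docuchat_web = True` on entry to `chat` marks the session as web-driven, so the fallback in `h2oai_pipeline.py` that forces `global_device` and `global_precision` does not clobber the device and precision selected in the web UI.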
