Skip to content

Commit

Permalink
Merge pull request vanna-ai#391 from jis478/main
Browse files Browse the repository at this point in the history
huggingface model support added
  • Loading branch information
zainhoda authored Apr 30, 2024
2 parents 994d3e0 + 6bb70ac commit 5c9af52
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
duckdb = ["duckdb"]
google = ["google-generativeai", "google-cloud-aiplatform"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers"]
test = ["tox"]
chromadb = ["chromadb"]
openai = ["openai"]
Expand All @@ -45,3 +45,4 @@ ollama = ["ollama", "httpx"]
qdrant = ["qdrant-client"]
vllm = ["vllm"]
opensearch = ["opensearch-py", "opensearch-dsl"]
hf = ["transformers"]
1 change: 1 addition & 0 deletions src/vanna/hf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .hf import Hf
79 changes: 79 additions & 0 deletions src/vanna/hf/hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import re
from transformers import AutoTokenizer, AutoModelForCausalLM

from ..base import VannaBase


class Hf(VannaBase):
def __init__(self, config=None):
model_name = self.config.get(
"model_name", None
) # e.g. meta-llama/Meta-Llama-3-8B-Instruct
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto",
)

def system_message(self, message: str) -> any:
return {"role": "system", "content": message}

def user_message(self, message: str) -> any:
return {"role": "user", "content": message}

def assistant_message(self, message: str) -> any:
return {"role": "assistant", "content": message}

def extract_sql_query(self, text):
"""
Extracts the first SQL statement after the word 'select', ignoring case,
matches until the first semicolon, three backticks, or the end of the string,
and removes three backticks if they exist in the extracted string.
Args:
- text (str): The string to search within for an SQL statement.
Returns:
- str: The first SQL statement found, with three backticks removed, or an empty string if no match is found.
"""
# Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)

match = pattern.search(text)
if match:
# Remove three backticks from the matched string if they exist
return match.group(0).replace("```", "")
else:
return text

def generate_sql(self, question: str, **kwargs) -> str:
# Use the super generate_sql
sql = super().generate_sql(question, **kwargs)

# Replace "\_" with "_"
sql = sql.replace("\\_", "_")

sql = sql.replace("\\", "")

return self.extract_sql_query(sql)

def submit_prompt(self, prompt, **kwargs) -> str:

input_ids = self.tokenizer.apply_chat_template(
prompt, add_generation_prompt=True, return_tensors="pt"
).to(self.model.device)

outputs = self.model.generate(
input_ids,
max_new_tokens=512,
eos_token_id=self.tokenizer.eos_token_id,
do_sample=True,
temperature=1,
top_p=0.9,
)
response = outputs[0][input_ids.shape[-1] :]
response = self.tokenizer.decode(response, skip_special_tokens=True)
self.log(response)

return response
2 changes: 2 additions & 0 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ def test_regular_imports():
from vanna.anthropic.anthropic_chat import Anthropic_Chat
from vanna.base.base import VannaBase
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
from vanna.hf.hf import Hf
from vanna.local import LocalContext_OpenAI
from vanna.marqo.marqo import Marqo_VectorStore
from vanna.mistral.mistral import Mistral
Expand All @@ -20,6 +21,7 @@ def test_shortcut_imports():
from vanna.anthropic import Anthropic_Chat
from vanna.base import VannaBase
from vanna.chromadb import ChromaDB_VectorStore
from vanna.hf import Hf
from vanna.marqo import Marqo_VectorStore
from vanna.mistral import Mistral
from vanna.ollama import Ollama
Expand Down

0 comments on commit 5c9af52

Please sign in to comment.