diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
index 87c45c7..1c22e89 100644
--- a/.github/workflows/python-publish.yml
+++ b/.github/workflows/python-publish.yml
@@ -10,7 +10,7 @@ name: Upload Python Package
 on:
   push:
-    branches: [main, development]
+    branches: [main, staging]
 
 permissions:
   contents: read
diff --git a/.gitignore b/.gitignore
index 02f52fc..df679ed 100644
--- a/.gitignore
+++ b/.gitignore
@@ -199,6 +199,7 @@ pyrightconfig.json
 ### Vector Store Instances ###
 chroma.db
 test_chroma.db
+.database/*
 
 ### Mac
 .DS_Store
\ No newline at end of file
diff --git a/README.md b/README.md
index 86730de..f56ae9d 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,11 @@
+# Ragstar - LLM Tools for DBT Projects
 
-# Ragstar
+Ragstar (inspired by `RAG & select *`) is a set of LLM-powered tools to elevate your dbt projects and supercharge your data team.
 
-Ragstar (inspired by `RAG & select *`) is a tool that enables you to ask ChatGPT questions about your dbt project.
+These tools include:
+
+- Chatbot: ask questions about your data and get answers based on your dbt model documentation.
+- Documentation Generator: generate documentation for dbt models based on the model and upstream model definitions.
 
 ## Get Started
 
@@ -13,9 +17,9 @@ Ragstar can be installed via pip.
 pip install ragstar
 ```
 
-### Basic Usage
+## Basic Usage - Chatbot
 
-How to multiply one number by another with this lib:
+How to load your dbt project into the Chatbot and ask questions about your data.
 
 ```Python
 from ragstar import Chatbot
@@ -31,21 +35,14 @@ chatbot.load_models()
 
 # Step 2. Ask the chatbot a question
 response = chatbot.ask_question(
-    'How can I obtain the number of customers who upgraded to a paid plan in the last 3 months?'
+    'How can I obtain the number of customers who upgraded to a paid plan in the last 3 months?'
 )
 print(response)
-
-# Step 3. Clear your local database (Optional).
-# You only need to do this if you would like to load a different project into your db
-# or restart from scratch for whatever reason.
-
-# If you make any changes to your existing models and load them again, they get upserted into the database.
-chatbot.reset_model_db()
 ```
 
**Note**: Ragstar currently only supports OpenAI ChatGPT models for generating embeddings and responses to queries.
 
-## How it works
+### How it works
 
 Ragstar is based on the concept of Retrieval Augmented Generation and basically works as follows:
 
@@ -55,10 +52,32 @@ Ragstar is based on the concept of Retrieval Augmented Generation and basically
 - These models are then fed into ChatGPT as a prompt, along with some basic instructions and your question.
 - The response is returned to you as a string.
 
+## Basic Usage - Documentation Generator
+
+How to load your dbt project into the Documentation Generator and have it write documentation for your models.
+
+```Python
+from ragstar import DocumentationGenerator
+
+# Instantiate a Documentation Generator object
+doc_gen = DocumentationGenerator(
+    dbt_project_root="YOUR_DBT_PROJECT_PATH",
+    openai_api_key="YOUR_OPENAI_API_KEY",
+)
+
+# Generate documentation for a model and all its upstream models
+doc_gen.generate_documentation(
+    model_name='dbt_model_name',
+    write_documentation_to_yaml=False
+)
+```
+
 ## Advanced Usage
+
 You can control the behaviour of some of the class member functions in more detail, or inspect the underlying classes for more functionality.
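+
+For example, here is a hedged sketch of what that can look like in practice. It assumes `load_models` accepts the same `models`, `included_folders` and `excluded_folders` filters that are documented for model loading below; the folder paths are placeholders:
+
+```Python
+from ragstar import Chatbot
+
+chatbot = Chatbot(
+    dbt_project_root="YOUR_DBT_PROJECT_PATH",
+    openai_api_key="YOUR_OPENAI_API_KEY",
+)
+
+# Load only a subset of models, e.g. everything under a hypothetical
+# models/marts folder except models/marts/legacy.
+chatbot.load_models(
+    included_folders=["models/marts"],
+    excluded_folders=["models/marts/legacy"],
+)
+
+# Inspect the tuning instructions and swap the underlying chat model.
+print(chatbot.get_instructions())
+chatbot.set_chatbot_model("gpt-4")
+```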
 The Chatbot is composed of two classes:
+
 - Vector Store
 - DBT Project
   - Composed of DBT Model
@@ -66,6 +85,7 @@ The Chatbot is composed of two classes:
 Here are the classes and methods they expose:
 
 ### Chatbot
+
 A class representing a chatbot that allows users to ask questions about dbt models.
 
 Attributes:
@@ -83,7 +103,8 @@ A class representing a chatbot that allows users to ask questions about dbt mode
 ### Methods
 
-#### __init__
+#### `__init__`
+
 Initializes a chatbot object along with a default set of instructions.
 
 Args:
@@ -94,15 +115,16 @@ Initializes a chatbot object along with a default set of instructions.
         Defaults to "text-embedding-3-large".
 
     chatbot_model (str, optional): The name of the OpenAI chatbot model to be used.
-    Defaults to "gpt-4-turbo-preview".
+        Defaults to "gpt-4-turbo-preview".
 
-    db_persist_path (str, optional): The path to the persistent database file.
-    Defaults to "./chroma.db".
+    vector_db_path (str, optional): The path to the persistent vector database file.
+        Defaults to "./database/chroma.db".
+
+    database_path (str, optional): The path to the JSON file that stores the parsed dbt project.
+        Defaults to "./database/directory.json".
 
 Returns:
     None
 
 #### load_models
+
 Upsert the set of models that will be available to your chatbot into a vector store. The chatbot will only be able to use these models to answer questions and nothing else.
 
 The default behavior is to load all models in the dbt project, but you can specify a subset of models, included folders or excluded folders to customize the set of models that will be available to the chatbot.
 
@@ -137,12 +159,14 @@ This will reset and remove all the models from the vector store. You'll need to
     None
 
 #### get_instructions
+
 Get the instructions being used to tune the chatbot.
 
 Returns:
     list[str]: A list of instructions being used to tune the chatbot.
 
 #### set_instructions
+
 Set the instructions for the chatbot.
 
 Args:
@@ -150,7 +174,9 @@ Set the instructions for the chatbot.
 
 Returns:
     None
+
 #### set_embedding_model
+
 Set the embedding model for the vector store.
 
 Args:
@@ -158,8 +184,9 @@ Set the embedding model for the vector store.
 
 Returns:
     None
-
+
 #### set_chatbot_model
+
 Set the chatbot model for the chatbot.
 
 Args:
@@ -169,9 +196,10 @@ Set the chatbot model for the chatbot.
     None
 
 ## Appendices
+
 These are the underlying classes that are used to compose the functionality of the chatbot.
 
-### Vector Store 
+### Vector Store
 
 A class representing a vector store for dbt models.
 
@@ -181,16 +209,18 @@ A class representing a vector store for dbt models.
     reset_collection: Clear the collection of all documents.
 
 ### DBT Project
- A class representing a DBT project yaml parser.
+
+A class representing a DBT project yaml parser.
 
 Attributes:
     project_root (str): Absolute path to the root of the dbt project being parsed
 
 ### DBT Model
+
 A class representing a dbt model.
 
 Attributes:
     name (str): The name of the model.
     description (str, optional): The description of the model.
     columns (list[DbtModelColumn], optional): A list of columns contained in the model.
-    May or may not be exhaustive.
+        May or may not be exhaustive.
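+
+To make the composition concrete, here is a minimal, hedged sketch of driving the DBT Project parser directly. The constructor arguments and methods follow the signatures introduced in this release; the paths and model names are placeholders:
+
+```Python
+from ragstar import DbtProject
+
+project = DbtProject(
+    dbt_project_root="YOUR_DBT_PROJECT_PATH",
+    database_path="./database/directory.json",
+)
+
+# Parse the project's SQL and YAML files into the JSON directory file,
+# then query the directory for single or filtered sets of models.
+project.parse()
+model = project.get_single_model("dbt_model_name")
+staging_models = project.get_models(included_folders=["models/staging"])
+```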
diff --git a/ragstar/__init__.py b/ragstar/__init__.py
index 7e2b69e..b891f2a 100644
--- a/ragstar/__init__.py
+++ b/ragstar/__init__.py
@@ -1,4 +1,17 @@
+from ragstar.types import (
+    PromptMessage,
+    ParsedSearchResult,
+    DbtModelDict,
+    DbtModelDirectoryEntry,
+)
+
+from ragstar.instructions import (
+    INTERPRET_MODEL_INSTRUCTIONS,
+    ANSWER_QUESTION_INSTRUCTIONS,
+)
+
 from ragstar.dbt_model import DbtModel
 from ragstar.dbt_project import DbtProject
 from ragstar.vector_store import VectorStore
 from ragstar.chatbot import Chatbot
+from ragstar.documentation_generator import DocumentationGenerator
diff --git a/ragstar/chatbot.py b/ragstar/chatbot.py
index 0a104ff..53428df 100644
--- a/ragstar/chatbot.py
+++ b/ragstar/chatbot.py
@@ -2,6 +2,7 @@
 
 from ragstar.types import PromptMessage, ParsedSearchResult
 
+from ragstar.instructions import ANSWER_QUESTION_INSTRUCTIONS
 from ragstar.dbt_project import DbtProject
 from ragstar.vector_store import VectorStore
 
@@ -30,7 +31,8 @@ def __init__(
         openai_api_key: str,
         embedding_model: str = "text-embedding-3-large",
         chatbot_model: str = "gpt-4-turbo-preview",
-        db_persist_path: str = "./chroma.db",
+        vector_db_path: str = "./database/chroma.db",
+        database_path: str = "./database/directory.json",
     ) -> None:
         """
         Initializes a chatbot object along with a default set of instructions.
@@ -53,30 +55,17 @@ def __init__(
         self.__chatbot_model: str = chatbot_model
         self.__openai_api_key: str = openai_api_key
 
-        self.project: DbtProject = DbtProject(dbt_project_root)
+        self.project: DbtProject = DbtProject(
+            dbt_project_root=dbt_project_root, database_path=database_path
+        )
+
         self.store: VectorStore = VectorStore(
-            openai_api_key, embedding_model, db_persist_path
+            openai_api_key, embedding_model, vector_db_path
         )
 
-        self.__instructions: list[str] = [
-            "You are a data analyst working with a data warehouse.",
-            "You should provide the user with the information they need to answer their question.",
-            "You should only provide information that you are confident is correct.",
-            "When you are not sure about the answer, you should let the user know.",
-            "If you are able to construct a SQL query that would answer the user's question, you should do so.",
-            "However please refrain from doing so if the user's question is ambiguous or unclear.",
-            "When writing a SQL query, you should only use column values if these values have been explicitly"
-            + " provided to you in the information you have been given.",
-            "Do not write a SQL query if you are unsure about the correctness of the query or"
-            + " about the values contained in the columns.",
-            "Only write a SQL query if you are confident that the query is exhaustive"
-            + " and that it will return the correct results.",
-            "If it is not possible to write a SQL that fulfils these conditions, you should instead respond"
-            + " with the names of the tables or columns that you think are relevant to the user's question.",
-            "You should also refrain from providing any information that is not directly related to the"
-            + " user's question or that which cannot be inferred from the information you have been given.",
-            "The following information about tables and columns is available to you:",
-        ]
+        self.client = OpenAI(api_key=self.__openai_api_key)
+
+        self.__instructions: list[str] = [ANSWER_QUESTION_INSTRUCTIONS]
 
     def __prepare_prompt(
         self, closest_models: list[ParsedSearchResult], query: str
@@ -186,7 +175,7 @@ def reset_model_db(self) -> None:
         """
         self.store.reset_collection()
 
-    def ask_question(self, get_models_name_only: bool = False) -> str:
+    def ask_question(self, query: str, get_model_names_only: bool = False) -> str:
         """
         Ask the chatbot a question about your dbt models and get a response.
         The chatbot looks the dbt models most similar to the user query and uses them to answer the question.
@@ -204,7 +193,7 @@ def ask_question(self, get_models_name_only: bool = False) -> str:
         closest_models = self.store.query_collection(query)
         model_names = ", ".join(map(lambda x: x["id"], closest_models))
 
-        if get_models_name_only:
+        if get_model_names_only:
             return model_names
 
         print("Closest models found:", model_names)
@@ -212,10 +201,8 @@ def ask_question(self, get_models_name_only: bool = False) -> str:
         print("\nPreparing prompt...")
         prompt = self.__prepare_prompt(closest_models, query)
 
-        client = OpenAI(api_key=self.__openai_api_key)
-
         print("\nCalculating response...")
-        completion = client.chat.completions.create(
+        completion = self.client.chat.completions.create(
             model=self.__chatbot_model,
             messages=prompt,
         )
diff --git a/ragstar/dbt_project.py b/ragstar/dbt_project.py
index 9f8d87d..0afb168 100644
--- a/ragstar/dbt_project.py
+++ b/ragstar/dbt_project.py
@@ -1,27 +1,46 @@
 import os
 import glob
+import re
+import json
+
+from typing import Union
+
 import yaml
 
+from ragstar.types import DbtModelDirectoryEntry, DbtProjectDirectory
 from ragstar.dbt_model import DbtModel
 
+SOURCE_SEARCH_EXPRESSION = r"source\(['\"]*(.*?)['\"]*?\)"
+REF_SEARCH_EXPRESSION = r"ref\(['\"]*(.*?)['\"]*\)"
+
 
 class DbtProject:
     """
-    A class representing a DBT project yaml parser.
-
-    Attributes:
-        project_root (str): Absolute path to the root of the dbt project being parsed
+    A class representing a DBT project.
     """
 
-    def __init__(self, project_root: str) -> None:
+    def __init__(
+        self,
+        dbt_project_root: str,
+        database_path: str = "./directory.json",
+    ) -> None:
         """
         Initializes a dbt project parser object.
 
         Args:
-            project_root (str): Root of the dbt prject
+            dbt_project_root (str): Root of the dbt project
+            database_path (str, optional): Path to the directory file that stores the parsed dbt project.
+
+        Methods:
+            parse: Parse the dbt project and store details in a manifest file.
+            get_single_model: Get a single model by name.
+            get_models: Get a list of models based on the provided filters.
+            update_model_directory: Update a model in the directory.
         """
-        self.__project_root = project_root
-        project_file = os.path.join(project_root, "dbt_project.yml")
+        self.__project_root = dbt_project_root
+        self.__directory_path = database_path
+
+        project_file = os.path.join(dbt_project_root, "dbt_project.yml")
 
         if not os.path.isfile(project_file):
             raise Exception("No dbt project found in the specified folder")
@@ -34,70 +53,268 @@ def __init__(self, project_root: str) -> None:
 
         self.__model_paths = project_config.get("model-paths", ["models"])
 
+        self.__sql_files = self.__get_all_files("sql")
+        self.__yaml_files = self.__get_all_files("yml")
+
+    def __get_all_files(self, file_extension: str):
+        """
+        Get all files of a certain type in the dbt project.
+
+        Args:
+            file_extension (str): The file extension to search for.
+
+        Returns:
+            list: A list of files with the specified extension.
+ """ + files = [] + + for path in self.__model_paths: + files.extend( + glob.glob( + os.path.join( + self.__project_root, path, "**", f"*.{file_extension}" + ), + recursive=True, + ) + ) + + return files + + def __find_upstream_references( + self, file_path: str, recursive: bool = False, dependencies: list[str] = None + ): + """ + Find upstream references in a SQL file. + + Args: + file_path (str): The path to the SQL file. + recursive (bool, optional): Whether to recursively search for upstream references. + dependencies (list, optional): A list of dependencies to add to. + + Returns: + list: A list of upstream references. + """ + if dependencies is None: + dependencies = [] + + with open(file_path, encoding="utf-8") as f: + file_contents = f.read() + + search_results = re.findall(REF_SEARCH_EXPRESSION, file_contents) + unique_results = list(set(search_results)) + + if recursive: + for result in unique_results: + sub_file_path = next( + (x for x in self.__sql_files if x.endswith(f"{result}.sql")), None + ) + dependencies = self.__find_upstream_references( + file_path=sub_file_path, recursive=True, dependencies=dependencies + ) + + return dependencies + unique_results + + def __parse_sql_file(self, sql_file: str): + """ + Parse a SQL file and return a dictionary with the file metadata. + + Args: + sql_file (str): The path to the SQL file. + + Returns: + dict: A dictionary containing the parsed SQL file metadata. + """ + with open(sql_file, encoding="utf-8") as f: + sql_contents = f.read() + + sources = [] + source_search = re.findall(SOURCE_SEARCH_EXPRESSION, sql_contents) + + for raw_source in source_search: + source = raw_source.replace("'", "").replace('"', "").split(",") + sources.append({"name": source[0], "table": source[1]}) + + return { + "absolute_path": sql_file, + "relative_path": sql_file.replace(self.__project_root, ""), + "name": os.path.basename(sql_file).replace(".sql", ""), + "refs": self.__find_upstream_references(sql_file, False), + "deps": self.__find_upstream_references(sql_file, True), + "sources": sources, + "sql_contents": sql_contents, + } + + def __parse_yaml_files(self, yaml_files: list[str]): + """ + Extract documentation from the parsed yaml files. + + Args: + yaml_files (list): A list of yaml files to parse. + + Returns: + dict: A dictionary containing the parsed models. + dict: A dictionary containing the parsed sources. + """ + models = {} + sources = {} + + for yaml_path in yaml_files: + with open(yaml_path, encoding="utf-8") as f: + yaml_contents = yaml.safe_load(f) + + for model in yaml_contents.get("models", []): + model["yaml_path"] = yaml_path + + parsed_columns = {} + for col in model.get("columns", []): + col_name = col.pop("name") + parsed_columns[col_name] = col + + model["columns"] = parsed_columns + + models[model["name"]] = model + + for source in yaml_contents.get("sources", []): + source["yaml_path"] = yaml_path + sources[source["name"]] = source + + return models, sources + + def __get_directory(self): + """ + Get the parsed directory from the directory file. + + Returns: + dict: The parsed directory. + """ + with open(self.__directory_path, encoding="utf-8") as f: + return json.load(f) + + def __save_directory(self, directory): + """ + Save the parsed directory to a file. + + Args: + directory (dict): The directory to save. 
+ """ + with open(self.__directory_path, "w", encoding="utf-8") as f: + json.dump(directory, f, ensure_ascii=False, indent=4) + + def parse(self) -> DbtProjectDirectory: + """ + Parse the dbt project and store details in a manifest file. + + Returns: + dict: The parsed directory. + """ + # source_sql_models = list(map(self.__parse_sql_file, self.__sql_files)) + source_sql_models = {} + + for sql_file in self.__sql_files: + parsed_model = self.__parse_sql_file(sql_file) + source_sql_models[parsed_model["name"]] = parsed_model + + documented_models, documented_sources = self.__parse_yaml_files( + self.__yaml_files + ) + + for model_name, model_dict in documented_models.items(): + yaml_path = model_dict.pop("yaml_path") + + if model_name in source_sql_models: + source_sql_models[model_name]["yaml_path"] = yaml_path + source_sql_models[model_name]["documentation"] = model_dict + else: + source_sql_models[model_name] = { + "yaml_path": yaml_path, + "documentation": model_dict, + } + + directory = { + "models": source_sql_models, + "sources": documented_sources, + } + + self.__save_directory(directory) + + return directory + + def get_single_model(self, model_name: str) -> Union[DbtModelDirectoryEntry, None]: + """ + Get a single model by name. + + Args: + model_name (str): The name of the model to get. + + Returns: + dict: The model object. + """ + if model_name is None: + raise Exception("No model name provided") + + directory = self.__get_directory() + + return directory["models"].get(model_name) + def get_models( self, models: list[str] = None, included_folders: list[str] = None, excluded_folders: list[str] = None, - ) -> list[DbtModel]: + ): """ - Scan all the YMLs in the specified folders and extract all models into a single list. + Get a list of models based on the provided filters. Args: - models (list[str], optional): A list of model names to include in the search. - - included_folders (list[str], optional): A list of paths to all folders that should be included - in model search. Paths are relative to dbt project root. - - exclude_folders (list[str], optional): A list of paths to all folders that should be excluded - in model search. Paths are relative to dbt project root. + models (list, optional): A list of model names to get. + included_folders (list, optional): A list of folders to include in the search for sql or yaml files. + excluded_folders (list, optional): A list of folders to exclude from the search for sql or yaml files. Returns: - list[DbtModel]: A list of Dbt Model objects for each model found in the included folders + list: A list of DbtModel objects. 
""" - parsed_models = [] - yaml_files = [] + searched_models = [] - if included_folders is None: - included_folders = self.__model_paths + directory = self.__get_directory() - for folder in included_folders: - if folder[0] == "/": - folder = folder[1:] + if models is None and included_folders is None: + searched_models = list(directory["models"].values()) - yaml_files.extend( - glob.glob( - os.path.join(self.__project_root, folder, "**", "*.yml"), - recursive=True, - ) - ) + for model in models or []: + searched_models.append(directory["models"].get(model)) - if not yaml_files: - raise Exception("No YAML files found in the specified folders") + for included_folder in included_folders or []: + for model in directory["models"].values(): + if included_folder in model.get( + "absolute_path", "" + ) or included_folder in model.get("yaml_path", ""): + searched_models.append(model) - for file in yaml_files: - should_exclude_file = False + for excluded_folder in excluded_folders or []: + for model in searched_models.copy(): + if excluded_folder in model.get( + "absolute_path", "" + ) or excluded_folder in model.get("yaml_path", ""): + searched_models.remove(model) - for excluded_folder in excluded_folders or []: - if excluded_folder in file: - should_exclude_file = True - continue + models_to_return = [] - if should_exclude_file: - continue + for model in searched_models: + if model["documentation"] is not None: + models_to_return.append(DbtModel(model["documentation"])) - with open(file, encoding="utf-8") as f: - yaml_contents = yaml.safe_load(f) + return models_to_return - if yaml_contents is None: - continue + def update_model_directory(self, model: dict): + """ + Update a model in the directory. - for model in yaml_contents.get("models", []): - if (models is not None) and (model.get("name") not in models): - continue - parsed_models.append(DbtModel(model)) + Args: + model (dict): The model to update. + """ + directory = self.__get_directory() - if not parsed_models: - raise Exception("No model ymls found in the specified folders") + if model["name"] in directory["models"]: + directory["models"][model["name"]] = model - return parsed_models + self.__save_directory(directory) diff --git a/ragstar/documentation_generator.py b/ragstar/documentation_generator.py new file mode 100644 index 0000000..323e28e --- /dev/null +++ b/ragstar/documentation_generator.py @@ -0,0 +1,217 @@ +import os +import json +import yaml + +from openai import OpenAI + +from ragstar.types import DbtModelDict, DbtModelDirectoryEntry, PromptMessage +from ragstar.instructions import INTERPRET_MODEL_INSTRUCTIONS +from ragstar.dbt_project import DbtProject + + +class MyDumper(yaml.Dumper): # pylint: disable=too-many-ancestors + """ + A custom yaml dumper that indents the yaml output like dbt does. + """ + + def increase_indent(self, flow=False, indentless=False): + return super().increase_indent(flow, False) + + +class DocumentationGenerator: + """ + A class that generates documentation for dbt models using large language models. + """ + + def __init__( + self, + dbt_project_root: str, + openai_api_key: str, + language_model: str = "gpt-4-turbo-preview", + database_path: str = "./directory.json", + ) -> None: + """ + Initializes a Documentation Generator object. + + Args: + dbt_project_root (str): Root of the dbt project + openai_api_key (str): OpenAI API key + language_model (str, optional): The language model to use for generating documentation. + Defaults to "gpt-4-turbo-preview". 
+            database_path (str, optional): Path to the directory file that stores the parsed dbt project.
+                Defaults to "./directory.json".
+
+        Attributes:
+            dbt_project (DbtProject): A DbtProject object representing the dbt project.
+
+        Methods:
+            interpret_model: Interpret a dbt model using the language model.
+            generate_documentation: Generate documentation for a dbt model.
+        """
+        self.dbt_project = DbtProject(
+            dbt_project_root=dbt_project_root, database_path=database_path
+        )
+
+        self.__language_model = language_model
+        self.__client = OpenAI(api_key=openai_api_key)
+
+    def __get_system_prompt(self, message: str) -> PromptMessage:
+        """
+        Get the system prompt for the language model.
+
+        Args:
+            message (str): The message to include in the system prompt.
+
+        Returns:
+            dict: The system prompt for the language model.
+        """
+        return {
+            "role": "system",
+            "content": message,
+        }
+
+    def __save_interpretation_to_yaml(
+        self, model: DbtModelDict, overwrite_existing: bool = False
+    ) -> None:
+        """
+        Save the interpretation of a model to a yaml file.
+
+        Args:
+            model (dict): The model to save the interpretation for.
+            overwrite_existing (bool, optional): Whether to overwrite the existing model
+                yaml documentation if it exists. Defaults to False.
+        """
+        yaml_path = model.get("yaml_path")
+
+        if yaml_path is not None:
+            if not overwrite_existing:
+                raise Exception(
+                    f"Model already has documentation at {model['yaml_path']}"
+                )
+
+            with open(model["yaml_path"], "r", encoding="utf-8") as infile:
+                existing_yaml = yaml.load(infile, Loader=yaml.FullLoader)
+                existing_models = existing_yaml.get("models", [])
+
+                search_idx = -1
+                for idx, m in enumerate(existing_models):
+                    if m["name"] == model["name"]:
+                        search_idx = idx
+
+                if search_idx != -1:
+                    existing_models[search_idx] = model["interpretation"]
+                else:
+                    existing_models.append(model["interpretation"])
+
+                existing_yaml["models"] = existing_models
+                yaml_content = existing_yaml
+        else:
+            model_path = model["absolute_path"]
+            head, tail = os.path.split(model_path)
+            yaml_path = os.path.join(head, "_" + tail.replace(".sql", ".yml"))
+
+            yaml_content = {"version": 2, "models": [model["interpretation"]]}
+
+        with open(yaml_path, "w", encoding="utf-8") as outfile:
+            yaml.dump(
+                yaml_content,
+                outfile,
+                Dumper=MyDumper,
+                default_flow_style=False,
+                sort_keys=False,
+            )
+
+    def interpret_model(self, model: DbtModelDirectoryEntry) -> DbtModelDict:
+        """
+        Interpret a dbt model using the large language model.
+
+        Args:
+            model (dict): The dbt model to interpret.
+
+        Returns:
+            dict: The interpretation of the model.
+        """
+        print(f"Interpreting model: {model['name']}")
+
+        prompt = []
+        refs = model.get("refs", [])
+
+        prompt.append(self.__get_system_prompt(INTERPRET_MODEL_INSTRUCTIONS))
+
+        prompt.append(
+            self.__get_system_prompt(
+                f"""
+                The model you are interpreting is called {model["name"]}. The following is the Jinja SQL code for the model:
+
+                {model.get("sql_contents")}
+                """
+            )
+        )
+
+        if len(refs) > 0:
+            prompt.append(
+                self.__get_system_prompt(
+                    f"""
+
+                    The model {model["name"]} references the following models: {", ".join(refs)}.
+                    The interpretation for each of these models is as follows:
+                    """
+                )
+            )
+
+            for ref in refs:
+                ref_model = self.dbt_project.get_single_model(ref)
+
+                prompt.append(
+                    self.__get_system_prompt(
+                        f"""
+
+                        The model {ref} is interpreted as follows:
+                        {json.dumps(ref_model.get("interpretation"), indent=4)}
+                        """
+                    )
+                )
+
+        completion = self.__client.chat.completions.create(
+            model=self.__language_model,
+            messages=prompt,
+        )
+
+        response = (
+            completion.choices[0]
+            .message.content.replace("```json", "")
+            .replace("```", "")
+        )
+
+        return json.loads(response)
+
+    def generate_documentation(
+        self, model_name: str, write_documentation_to_yaml: bool = False
+    ) -> DbtModelDict:
+        """
+        Generate documentation for a dbt model.
+
+        Args:
+            model_name (str): The name of the model to generate documentation for.
+            write_documentation_to_yaml (bool, optional): Whether to save the documentation to a yaml file.
+                Defaults to False.
+        """
+        model = self.dbt_project.get_single_model(model_name)
+
+        for dep in model.get("deps", []):
+            dep_model = self.dbt_project.get_single_model(dep)
+
+            if dep_model.get("interpretation") is None:
+                dep_model["interpretation"] = self.interpret_model(dep_model)
+                self.dbt_project.update_model_directory(dep_model)
+
+        interpretation = self.interpret_model(model)
+
+        model["interpretation"] = interpretation
+
+        if write_documentation_to_yaml:
+            self.__save_interpretation_to_yaml(model)
+
+        self.dbt_project.update_model_directory(model)
+
+        return interpretation
diff --git a/ragstar/instructions.py b/ragstar/instructions.py
new file mode 100644
index 0000000..5932648
--- /dev/null
+++ b/ragstar/instructions.py
@@ -0,0 +1,42 @@
+INTERPRET_MODEL_INSTRUCTIONS = r"""
+    You are a data analyst trying to understand the meaning and schema of a dbt model.
+    You will be provided with the name of the model and the Jinja SQL code that defines the model.
+
+    The Jinja files may contain references to other models, using the {{ ref('model_name') }} syntax,
+    or references to source tables using the {{ source('schema_name', 'table_name') }} syntax.
+
+    The interpretation for all upstream models will be provided to you in the form of a
+    JSON object that contains the following keys: model, description, columns.
+
+    A source table is a table that is not defined in the dbt project, but is instead a table that is present in the data warehouse.
+
+    Your response should be in the form of a JSON object that contains the following keys: model, description, columns.
+
+    The columns key should contain a list of JSON objects, each of which should contain
+    the following keys: name, description.
+
+    Your response should only contain an unformatted JSON string described above and nothing else.
+"""
+
+ANSWER_QUESTION_INSTRUCTIONS = r"""
+    You are a data analyst working with a data warehouse. You should provide the user with the information
+    they need to answer their question.
+
+    You should only provide information that you are confident is correct. When you are not sure about the answer,
+    you should let the user know.
+
+    If you are able to construct a SQL query that would answer the user's question, you should do so. However
+    please refrain from doing so if the user's question is ambiguous or unclear. When writing a SQL query,
+    you should only use column values if these values have been explicitly provided to you in the information
+    you have been given.
+
+    Do not write a SQL query if you are unsure about the correctness of the query or about the values contained
+    in the columns. Only write a SQL query if you are confident that the query is exhaustive and that it will
+    return the correct results. If it is not possible to write a SQL query that fulfils these conditions,
+    you should instead respond with the names of the tables or columns that you think are relevant to the user's question.
+
+    You should also refrain from providing any information that is not directly related to the user's question or
+    that cannot be inferred from the information you have been given.
+
+    The following information about tables and columns is available to you:
+"""
diff --git a/ragstar/types.py b/ragstar/types.py
index 7601b1b..42affec 100644
--- a/ragstar/types.py
+++ b/ragstar/types.py
@@ -1,4 +1,4 @@
-from typing import TypedDict
+from typing import TypedDict, Union
 
 from typing_extensions import NotRequired
 
@@ -18,7 +18,35 @@ class DbtModelDict(TypedDict):
 
     name: str
     description: NotRequired[str]
-    columns: list[DbtModelColumn]
+    columns: NotRequired[list[DbtModelColumn]]
+    tests: NotRequired[list[Union[dict, str]]]
+    config: NotRequired[dict]
+    yaml_path: NotRequired[str]
+
+
+class DbtModelDirectoryEntry(TypedDict):
+    """
+    Type for a dictionary representing an entry in a dbt model directory
+    """
+
+    absolute_path: str
+    relative_path: str
+    name: str
+    refs: list[str]
+    deps: list[str]
+    sources: list[dict]
+    sql_contents: str
+    documentation: DbtModelDict
+    interpretation: DbtModelDict
+
+
+class DbtProjectDirectory(TypedDict):
+    """
+    Type for a dictionary representing a dbt project directory
+    """
+
+    models: dict[str, DbtModelDirectoryEntry]
+    sources: NotRequired[dict[str, dict]]
 
 
 class PromptMessage(TypedDict):
diff --git a/requirements.txt b/requirements.txt
index 214c597..0e4680e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 pyyaml
 typing_extensions
+tinydb
 pylint
 chromadb
 openai
diff --git a/setup.py b/setup.py
index e49bfd7..9fec185 100644
--- a/setup.py
+++ b/setup.py
@@ -12,7 +12,7 @@
 setup(
     name="ragstar",
-    version="0.1.4",
+    version="0.2.1",
     description="RAG based LLM chatbot for dbt projects.",
     long_description=long_description,
     long_description_content_type="text/markdown",
diff --git a/test/test_data/directory.json b/test/test_data/directory.json
new file mode 100644
index 0000000..d1995df
--- /dev/null
+++ b/test/test_data/directory.json
@@ -0,0 +1,84 @@
+{
+    "models": {
+        "model_2": {
+            "yaml_path": "/Users/pragunbhutani/Code/ragstar/test/test_data/valid_dbt_project/models/model_2.yml",
+            "documentation": {
+                "name": "model_2",
+                "description": "model_2_description",
+                "columns": {
+                    "col_1": {
+                        "description": "col_1_description"
+                    },
+                    "col_2": {}
+                }
+            }
+        },
+        "model_1": {
+            "yaml_path": "/Users/pragunbhutani/Code/ragstar/test/test_data/valid_dbt_project/models/model_1.yml",
+            "documentation": {
+                "name": "model_1",
+                "description": "model_1_description",
+                "columns": {
+                    "col_1": {
+                        "description": "col_1_description"
+                    },
+                    "col_2": {
+                        "description": "col_2_description"
+                    }
+                }
+            }
+        },
+        "staging_1": {
+            "yaml_path": "/Users/pragunbhutani/Code/ragstar/test/test_data/valid_dbt_project/models/staging/schema.yml",
+            "documentation": {
+                "name": "staging_1",
+                "columns": {
+                    "col_1": {
+                        "tests": [
+                            "unique",
+                            "not_null"
+                        ]
+                    }
+                }
+            }
+        },
+        "staging_2": {
+            "yaml_path": "/Users/pragunbhutani/Code/ragstar/test/test_data/valid_dbt_project/models/staging/schema.yml",
+            "documentation": {
"staging_2", + "columns": { + "col_with_description": { + "description": "col_with_description_description", + "tests": [ + "unique", + "not_null" + ] + }, + "col_without_description": { + "tests": [ + { + "accepted_values": { + "values": [ + "placed", + "shipped", + "completed", + "return_pending", + "returned" + ] + } + } + ] + } + } + } + }, + "intermediate_1": { + "yaml_path": "/Users/pragunbhutani/Code/ragstar/test/test_data/valid_dbt_project/models/intermediate/schema.yml", + "documentation": { + "name": "intermediate_1", + "columns": {} + } + } + }, + "sources": {} +} \ No newline at end of file diff --git a/test/test_dbt_project.py b/test/test_dbt_project.py index 5757191..6d92922 100644 --- a/test/test_dbt_project.py +++ b/test/test_dbt_project.py @@ -5,6 +5,7 @@ HERE = os.path.abspath(os.path.dirname(__file__)) VALID_PROJECT_PATH = os.path.join(HERE, "test_data/valid_dbt_project") +DATABASE_PATH = os.path.join(HERE, "test_data/directory.json") class DbtProjectTestCase(unittest.TestCase): @@ -23,7 +24,10 @@ def test_class_constructed_with_valid_project_root(self): """ Test for the case when the class is constructed with a valid project root. """ - project = DbtProject(VALID_PROJECT_PATH) + project = DbtProject( + VALID_PROJECT_PATH, + database_path=DATABASE_PATH, + ) self.assertIsInstance(project, DbtProject) @@ -31,7 +35,10 @@ def test_get_models_all_folders(self): """ Test for the case when we want to get all the models in the project. """ - project = DbtProject(VALID_PROJECT_PATH) + project = DbtProject( + VALID_PROJECT_PATH, + database_path=DATABASE_PATH, + ) models = project.get_models() self.assertEqual(len(models), 5) @@ -40,7 +47,10 @@ def test_get_models_with_included_folders(self): """ Test for the case when we want to get all the models in one/many specific folder(s). """ - project = DbtProject(VALID_PROJECT_PATH) + project = DbtProject( + VALID_PROJECT_PATH, + database_path=DATABASE_PATH, + ) models = project.get_models( included_folders=["models/staging", "models/intermediate"] ) @@ -59,7 +69,10 @@ def test_get_models_with_excluded_folders(self): Test for the case when we want to get all the models in the project, except for those in one/many specific folder(s). """ - project = DbtProject(VALID_PROJECT_PATH) + project = DbtProject( + VALID_PROJECT_PATH, + database_path=DATABASE_PATH, + ) models = project.get_models(excluded_folders=["models/intermediate"]) self.assertEqual(len(models), 4) @@ -70,7 +83,10 @@ def test_get_models_by_name(self): """ Test for the case when we want to get only specific models by name. """ - project = DbtProject(VALID_PROJECT_PATH) + project = DbtProject( + VALID_PROJECT_PATH, + database_path=DATABASE_PATH, + ) models = project.get_models(models=["staging_1", "staging_2"]) self.assertEqual(len(models), 2) diff --git a/test/test_documentation_generator.py b/test/test_documentation_generator.py new file mode 100644 index 0000000..474f7ce --- /dev/null +++ b/test/test_documentation_generator.py @@ -0,0 +1,35 @@ +import os +import unittest + +from ragstar import DocumentationGenerator + +HERE = os.path.abspath(os.path.dirname(__file__)) +VALID_PROJECT_PATH = os.path.join(HERE, "test_data/valid_dbt_project") + + +class DocumentationGeneratorTestCase(unittest.TestCase): + """ + Test cases for the DocumentationGenerator class. + """ + + def test_project_root_is_not_dbt_project(self): + """ + Test for the case when the project root is not a dbt project. 
+ """ + with self.assertRaises(Exception): + DocumentationGenerator("invalid_path", "api_key") + + def test_class_constructed_with_valid_project_root(self): + """ + Test for the case when the class is constructed with a valid project root. + """ + generator = DocumentationGenerator(VALID_PROJECT_PATH, "api_key") + + self.assertIsInstance(generator, DocumentationGenerator) + + def test_class_constructed_without_api_key(self): + """ + Test for the case when the class is constructed without an api key. + """ + with self.assertRaises(Exception): + DocumentationGenerator(VALID_PROJECT_PATH, None)