Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions llama-index-core/llama_index/core/indices/struct_store/sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,14 +300,17 @@ async def _aquery(self, query_bundle: QueryBundle) -> Response:


def _validate_prompt(
    response_synthesis_prompt: BasePromptTemplate,
    default_synthesis_prompt: BasePromptTemplate,
) -> None:
    """Check that a synthesis prompt exposes the expected template variables.

    Args:
        response_synthesis_prompt: Prompt whose ``template_vars`` are checked.
        default_synthesis_prompt: Prompt whose ``template_vars`` serve as the
            reference.

    Raises:
        ValueError: If the two prompts' ``template_vars`` compare unequal.
    """
    expected_vars = default_synthesis_prompt.template_vars
    if response_synthesis_prompt.template_vars == expected_vars:
        return
    raise ValueError(
        "response_synthesis_prompt must have the following template variables: "
        f"{default_synthesis_prompt.template_vars}"
    )


Expand All @@ -330,6 +333,7 @@ def __init__(
verbose: bool = False,
# deprecated
service_context: Optional[ServiceContext] = None,
skip_table_verification: bool = False,
**kwargs: Any,
) -> None:
"""Initialize params."""
Expand All @@ -351,6 +355,7 @@ def __init__(

self._synthesize_response = synthesize_response
self._verbose = verbose
self._skip_table_verification = skip_table_verification
super().__init__(
callback_manager=callback_manager
or callback_manager_from_settings_or_context(Settings, service_context),
Expand Down Expand Up @@ -383,7 +388,7 @@ def service_context(self) -> Optional[ServiceContext]:
def _query(self, query_bundle: QueryBundle) -> Response:
"""Answer a query."""
retrieved_nodes, metadata = self.sql_retriever.retrieve_with_metadata(
query_bundle
query_bundle, self._skip_table_verification
)

sql_query_str = metadata["sql_query"]
Expand Down Expand Up @@ -459,11 +464,13 @@ def __init__(
response_synthesis_prompt: Optional[BasePromptTemplate] = None,
refine_synthesis_prompt: Optional[BasePromptTemplate] = None,
tables: Optional[Union[List[str], List[Table]]] = None,
service_context: Optional[ServiceContext] = None,
service_context_text_to_sql: Optional[ServiceContext] = None,
service_context_synthesis: Optional[ServiceContext] = None,
context_str_prefix: Optional[str] = None,
sql_only: bool = False,
callback_manager: Optional[CallbackManager] = None,
sql_only: bool = False,
verbose: bool = False,
skip_table_verification: bool = False,
**kwargs: Any,
) -> None:
"""Initialize params."""
Expand All @@ -475,7 +482,7 @@ def __init__(
context_query_kwargs=context_query_kwargs,
tables=tables,
context_str_prefix=context_str_prefix,
service_context=service_context,
service_context=service_context_text_to_sql,
sql_only=sql_only,
callback_manager=callback_manager,
verbose=verbose,
Expand All @@ -485,9 +492,10 @@ def __init__(
response_synthesis_prompt=response_synthesis_prompt,
refine_synthesis_prompt=refine_synthesis_prompt,
llm=llm,
service_context=service_context,
service_context=service_context_synthesis,
callback_manager=callback_manager,
verbose=verbose,
skip_table_verification=skip_table_verification,
**kwargs,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,17 +281,20 @@ def _load_get_tables_fn(
return lambda _: table_schemas

def retrieve_with_metadata(
self, str_or_query_bundle: QueryType
self, str_or_query_bundle: QueryType, skip_table_verification: bool
) -> Tuple[List[NodeWithScore], Dict]:
"""Retrieve with metadata."""
if isinstance(str_or_query_bundle, str):
query_bundle = QueryBundle(str_or_query_bundle)
else:
query_bundle = str_or_query_bundle
table_desc_str = self._get_table_context(query_bundle)
logger.info(f"> Table desc str: {table_desc_str}")
if self._verbose:
print(f"> Table desc str: {table_desc_str}")

table_desc_str = ""
if not skip_table_verification:
table_desc_str = self._get_table_context(query_bundle)
logger.info(f"> Table desc str: {table_desc_str}")
if self._verbose:
print(f"> Table desc str: {table_desc_str}")

response_str = self._llm.predict(
self._text_to_sql_prompt,
Expand Down
4 changes: 2 additions & 2 deletions llama-index-core/llama_index/core/postprocessor/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def __init__(
if tokenizer_fn is None:
import nltk.data

tokenizer = nltk.data.load("tokenizers/punkt/english.pickle")
tokenizer_fn = tokenizer.tokenize
tokenizer = nltk.data.load("tokenizers/punkt/english.pickle", quiet=True)
tokenizer_fn = tokenizer.tokenize
self._tokenizer_fn = tokenizer_fn

super().__init__(
Expand Down
67 changes: 43 additions & 24 deletions llama-index-core/llama_index/core/utilities/sql_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""SQL wrapper around SQLDatabase in langchain."""
from typing import Any, Dict, Iterable, List, Optional, Tuple

from typing import Any, Dict, Iterable, List, Optional, Tuple, Callable

from sqlalchemy import MetaData, create_engine, insert, inspect, text
from sqlalchemy.engine import Engine
Expand Down Expand Up @@ -49,21 +50,31 @@ def __init__(
custom_table_info: Optional[dict] = None,
view_support: bool = False,
max_string_length: int = 300,
preprocess_query_function: Callable[..., Any] = None,
skip_table_verification: bool = False,
):
"""Create engine from database URI."""
self._engine = engine
self._schema = schema
self._skip_table_verification = skip_table_verification
if include_tables and ignore_tables:
raise ValueError("Cannot specify both include_tables and ignore_tables")

self._inspector = inspect(self._engine)

# including view support by adding the views as well as tables to the all
# tables list if view_support is True
self._all_tables = set(
self._inspector.get_table_names(schema=schema)
+ (self._inspector.get_view_names(schema=schema) if view_support else [])
)
if self._skip_table_verification:
self._all_tables = set(include_tables)
else:
self._all_tables = set(
self._inspector.get_table_names(schema=schema)
+ (
self._inspector.get_view_names(schema=schema)
if view_support
else []
)
)

self._include_tables = set(include_tables) if include_tables else set()
if self._include_tables:
Expand Down Expand Up @@ -107,12 +118,20 @@ def __init__(

self._metadata = metadata or MetaData()
# including view support if view_support = true
self._metadata.reflect(
views=view_support,
bind=self._engine,
only=list(self._usable_tables),
schema=self._schema,
)
if self._skip_table_verification:
self._metadata.reflect(
views=view_support,
bind=self._engine,
schema=self._schema,
)
else:
self._metadata.reflect(
views=view_support,
bind=self._engine,
only=list(self._usable_tables),
schema=self._schema,
)
self._preprocess_query_function = preprocess_query_function

@property
def engine(self) -> Engine:
Expand Down Expand Up @@ -150,19 +169,10 @@ def get_table_columns(self, table_name: str) -> List[Any]:
def get_single_table_info(self, table_name: str) -> str:
"""Get table info for a single table."""
# same logic as table_info, but with specific table names
template = "Table '{table_name}' has columns: {columns}, "
try:
# try to retrieve table comment
table_comment = self._inspector.get_table_comment(
table_name, schema=self._schema
)["text"]
if table_comment:
template += f"with comment: ({table_comment}) "
except NotImplementedError:
# get_table_comment raises NotImplementedError for a dialect that does not support comments.
pass

template += "and foreign keys: {foreign_keys}."
template = (
"Table '{table_name}' has columns: {columns}, "
"and foreign keys: {foreign_keys}."
)
columns = []
for column in self._inspector.get_columns(table_name, schema=self._schema):
if column.get("comment"):
Expand Down Expand Up @@ -207,12 +217,21 @@ def truncate_word(self, content: Any, *, length: int, suffix: str = "...") -> st

return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix

def _preprocess_query(self, query: str) -> str:
    """Apply the optional preprocessing hook to a SQL query.

    Args:
        query: SQL text to preprocess.

    Returns:
        The result of calling ``self._preprocess_query_function`` on
        ``query`` when that attribute is truthy; otherwise ``query``
        unchanged.
    """
    hook = self._preprocess_query_function
    if not hook:
        # No preprocessing function was provided: pass the query through.
        return query
    return hook(query)

def run_sql(self, command: str) -> Tuple[str, Dict]:
"""Execute a SQL statement and return a string representing the results.

If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
"""
command = self._preprocess_query(command)
with self._engine.begin() as connection:
try:
if self._schema:
Expand Down
16 changes: 9 additions & 7 deletions llama-index-core/llama_index/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,11 @@ def __init__(self) -> None:
try:
nltk.data.find("corpora/stopwords", paths=[self._nltk_data_dir])
except LookupError:
nltk.download("stopwords", download_dir=self._nltk_data_dir)

nltk.download("stopwords", download_dir=self._nltk_data_dir, quiet=True)
try:
nltk.data.find("tokenizers/punkt", paths=[self._nltk_data_dir])
except LookupError:
nltk.download("punkt", download_dir=self._nltk_data_dir)
nltk.download("punkt", download_dir=self._nltk_data_dir, quiet=True)

@property
def stopwords(self) -> List[str]:
Expand All @@ -73,15 +72,19 @@ def stopwords(self) -> List[str]:
try:
import nltk
from nltk.corpus import stopwords

nltk.download("wordnet", quiet=True)
except ImportError:
raise ImportError(
"`nltk` package not found, please run `pip install nltk`"
)

try:
nltk.data.find("corpora/stopwords", paths=[self._nltk_data_dir])
nltk.data.find(
"corpora/stopwords", paths=[self._nltk_data_dir], quiet=True
)
except LookupError:
nltk.download("stopwords", download_dir=self._nltk_data_dir)
nltk.download("stopwords", download_dir=self._nltk_data_dir, quiet=True)
self._stopwords = stopwords.words("english")
return self._stopwords

Expand All @@ -92,8 +95,7 @@ def stopwords(self) -> List[str]:
# Global Tokenizer
@runtime_checkable
class Tokenizer(Protocol):
def encode(self, text: str, *args: Any, **kwargs: Any) -> List[Any]:
...
def encode(self, text: str, *args: Any, **kwargs: Any) -> List[Any]: ...


def set_global_tokenizer(tokenizer: Union[Tokenizer, Callable[[str], list]]) -> None:
Expand Down
49 changes: 38 additions & 11 deletions llama-index-legacy/llama_index/legacy/utilities/sql_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""SQL wrapper around SQLDatabase in langchain."""
from typing import Any, Dict, Iterable, List, Optional, Tuple

from typing import Any, Dict, Iterable, List, Optional, Tuple, Callable

from sqlalchemy import MetaData, create_engine, insert, inspect, text
from sqlalchemy.engine import Engine
Expand Down Expand Up @@ -49,6 +50,8 @@ def __init__(
custom_table_info: Optional[dict] = None,
view_support: bool = False,
max_string_length: int = 300,
preprocess_query_function: Callable[..., Any] = None,
skip_table_verification: bool = False,
):
"""Create engine from database URI."""
self._engine = engine
Expand All @@ -60,10 +63,17 @@ def __init__(

# including view support by adding the views as well as tables to the all
# tables list if view_support is True
self._all_tables = set(
self._inspector.get_table_names(schema=schema)
+ (self._inspector.get_view_names(schema=schema) if view_support else [])
)
if skip_table_verification:
self._all_tables = set(include_tables)
else:
self._all_tables = set(
self._inspector.get_table_names(schema=schema)
+ (
self._inspector.get_view_names(schema=schema)
if view_support
else []
)
)

self._include_tables = set(include_tables) if include_tables else set()
if self._include_tables:
Expand Down Expand Up @@ -107,12 +117,20 @@ def __init__(

self._metadata = metadata or MetaData()
# including view support if view_support = true
self._metadata.reflect(
views=view_support,
bind=self._engine,
only=list(self._usable_tables),
schema=self._schema,
)
if skip_table_verification:
self._metadata.reflect(
views=view_support,
bind=self._engine,
schema=self._schema,
)
else:
self._metadata.reflect(
views=view_support,
bind=self._engine,
only=list(self._usable_tables),
schema=self._schema,
)
self.preprocess_query_function = preprocess_query_function

@property
def engine(self) -> Engine:
Expand Down Expand Up @@ -198,12 +216,21 @@ def truncate_word(self, content: Any, *, length: int, suffix: str = "...") -> st

return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix

def _preprocess_query(self, query: str) -> str:
    """Apply the optional preprocessing hook to a SQL query.

    Args:
        query: SQL text to preprocess.

    Returns:
        The result of calling ``self.preprocess_query_function`` on
        ``query`` when that attribute is truthy; otherwise ``query``
        unchanged.
    """
    hook = self.preprocess_query_function
    if not hook:
        # No preprocessing function was provided: pass the query through.
        return query
    return hook(query)

def run_sql(self, command: str) -> Tuple[str, Dict]:
"""Execute a SQL statement and return a string representing the results.

If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
"""
command = self._preprocess_query(command)
with self._engine.begin() as connection:
try:
if self._schema:
Expand Down