From 53be34ab67904384e85a7a113b757e2fbf9a8ad1 Mon Sep 17 00:00:00 2001
From: Pablo Minue
Date: Sun, 7 Jul 2024 15:44:27 +0200
Subject: [PATCH] Added get_queries_from_path function, applied some linting to
the code
---
pysqltools/__init__.py | 6 +--
pysqltools/src/SQL/__init__.py | 5 ---
pysqltools/src/__init__.py | 6 +--
pysqltools/src/connection/connection.py | 17 ++++++-
pysqltools/src/log/__init__.py | 2 +-
pysqltools/src/log/log.py | 7 +--
pysqltools/src/sql/__init__.py | 8 ++++
pysqltools/src/{SQL => sql}/constants.py | 0
pysqltools/src/{SQL => sql}/exceptions.py | 0
pysqltools/src/{SQL => sql}/insert.py | 5 +--
pysqltools/src/{SQL => sql}/query.py | 54 ++++++++++++++++++++---
pysqltools/src/{SQL => sql}/table.py | 0
"pysqltools/\302\241" | 0
tests/test_connections.py | 2 +-
tests/test_queries.py | 2 +-
15 files changed, 85 insertions(+), 29 deletions(-)
delete mode 100644 pysqltools/src/SQL/__init__.py
create mode 100644 pysqltools/src/sql/__init__.py
rename pysqltools/src/{SQL => sql}/constants.py (100%)
rename pysqltools/src/{SQL => sql}/exceptions.py (100%)
rename pysqltools/src/{SQL => sql}/insert.py (97%)
rename pysqltools/src/{SQL => sql}/query.py (85%)
rename pysqltools/src/{SQL => sql}/table.py (100%)
create mode 100644 "pysqltools/\302\241"
diff --git a/pysqltools/__init__.py b/pysqltools/__init__.py
index d080443..8b33600 100644
--- a/pysqltools/__init__.py
+++ b/pysqltools/__init__.py
@@ -1,7 +1,7 @@
from .src.connection import SQLConnection
-from .src.SQL.insert import generate_insert_query, insert_pandas
-from .src.SQL.query import Query, SQLString
-from .src.SQL.table import Table
+from .src.sql.insert import generate_insert_query, insert_pandas
+from .src.sql.query import Query, SQLString, get_queries_from_path
+from .src.sql.table import Table
def format_sql(sql: str, **kwargs) -> str:
diff --git a/pysqltools/src/SQL/__init__.py b/pysqltools/src/SQL/__init__.py
deleted file mode 100644
index 8f69434..0000000
--- a/pysqltools/src/SQL/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""
-Queries Package. Contains everything SQL-Text query related
-"""
-
-from pysqltools.src.SQL.query import Query, SQLString
diff --git a/pysqltools/src/__init__.py b/pysqltools/src/__init__.py
index 9cad2bc..74b2fb7 100644
--- a/pysqltools/src/__init__.py
+++ b/pysqltools/src/__init__.py
@@ -2,6 +2,6 @@
Source code for the pysqltools package
"""
-from pysqltools.src.SQL.insert import generate_insert_query, insert_pandas
-from pysqltools.src.SQL.query import Query, SQLString
-from pysqltools.src.SQL.table import Table
+from pysqltools.src.sql.insert import generate_insert_query, insert_pandas
+from pysqltools.src.sql.query import Query, SQLString, get_queries_from_path
+from pysqltools.src.sql.table import Table
diff --git a/pysqltools/src/connection/connection.py b/pysqltools/src/connection/connection.py
index 3cfb2e4..8eae077 100644
--- a/pysqltools/src/connection/connection.py
+++ b/pysqltools/src/connection/connection.py
@@ -4,6 +4,7 @@
import ibm_db
import mysql
import mysql.connector
+import pandas as pd
import pymssql
import pymysql
@@ -13,12 +14,16 @@
from pysqltools.src.connection.exceptions import ConnectionException
from pysqltools.src.log import PabLog
-from pysqltools.src.SQL.query import Query
+from pysqltools.src.sql.query import Query
lg = PabLog("Connections")
class SQLConnection:
+ """
+ Unified connection class for different SQL Dialects.
+ """
+
conn = None
def __init__(
@@ -38,6 +43,9 @@ def __init__(
self.conn = conn
def execute(self, sql: Query) -> None:
+ """
+ Execute a SQL Statement that returns no value
+ """
try:
if isinstance(self.conn, ibm_db.IBM_DBConnection):
ibm_db.exec_immediate(self.conn, sql.sql)
@@ -55,7 +63,10 @@ def execute(self, sql: Query) -> None:
except:
raise ConnectionException
- def fetch(self, sql: Query):
+ def fetch(self, sql: Query, dataframe: bool = False):
+ """
+ Execute a SQL Query object and get the resulting rows (a pandas DataFrame if `dataframe` is True).
+ """
try:
if isinstance(self.conn, ibm_db.IBM_DBConnection):
stmt = ibm_db.exec_immediate(self.conn, sql.sql)
@@ -81,6 +92,8 @@ def fetch(self, sql: Query):
while row:
rows.append(row)
row = cursor.fetchone()
+ if dataframe:
+ return pd.DataFrame(rows)
return rows
except Exception as e:
lg.log.error(f"Fetch failed: {e}")
diff --git a/pysqltools/src/log/__init__.py b/pysqltools/src/log/__init__.py
index 3b69233..7413192 100644
--- a/pysqltools/src/log/__init__.py
+++ b/pysqltools/src/log/__init__.py
@@ -1 +1 @@
-from .log import PabLog
+from .log import PabLog, progress_function
diff --git a/pysqltools/src/log/log.py b/pysqltools/src/log/log.py
index 0b252ea..f8bd18f 100644
--- a/pysqltools/src/log/log.py
+++ b/pysqltools/src/log/log.py
@@ -4,7 +4,6 @@
from typing import Any, Callable
import pandas as pd
-import rich
from rich.console import Console
from rich.logging import RichHandler
from rich.markdown import Markdown
@@ -44,6 +43,7 @@ def __init__(
logging.basicConfig(format=__format, handlers=__handlers, level=logging.DEBUG)
self.log = logging.getLogger(log_name)
+ self.log.setLevel(logging.INFO)
self.console = Console()
def add_table(self, df: pd.DataFrame, title: str = "", max_rows: int = 10) -> None:
@@ -88,8 +88,9 @@ def decorator(fun: Callable[..., Any]):
def inner(*args, **kwargs):
"""Wrapped"""
with Progress() as progress:
- task = progress.add_task("f[{color}] {task_name}...", total=total)
- result = fun(progress, task, *args, **kwargs)
+ task = progress.add_task(f"[{color}] {task_name}...", total=total)
+ prog = {"progress": progress, "task": task}
+ result = fun(*args, **kwargs, **prog)
progress.update(task, advance=total)
return result
diff --git a/pysqltools/src/sql/__init__.py b/pysqltools/src/sql/__init__.py
new file mode 100644
index 0000000..0fd5c84
--- /dev/null
+++ b/pysqltools/src/sql/__init__.py
@@ -0,0 +1,8 @@
+"""
+Queries Package. Contains everything SQL-Text query related
+"""
+
+from pysqltools.src.connection import SQLConnection
+from pysqltools.src.sql.insert import generate_insert_query, insert_pandas
+from pysqltools.src.sql.query import Query, SQLString, get_queries_from_path
+from pysqltools.src.sql.table import Table
diff --git a/pysqltools/src/SQL/constants.py b/pysqltools/src/sql/constants.py
similarity index 100%
rename from pysqltools/src/SQL/constants.py
rename to pysqltools/src/sql/constants.py
diff --git a/pysqltools/src/SQL/exceptions.py b/pysqltools/src/sql/exceptions.py
similarity index 100%
rename from pysqltools/src/SQL/exceptions.py
rename to pysqltools/src/sql/exceptions.py
diff --git a/pysqltools/src/SQL/insert.py b/pysqltools/src/sql/insert.py
similarity index 97%
rename from pysqltools/src/SQL/insert.py
rename to pysqltools/src/sql/insert.py
index 2e4b330..1d3dc95 100644
--- a/pysqltools/src/SQL/insert.py
+++ b/pysqltools/src/sql/insert.py
@@ -1,4 +1,3 @@
-import time
from datetime import date
from typing import Any, Generator
@@ -8,7 +7,7 @@
from pysqltools.src.connection import SQLConnection
from pysqltools.src.log import PabLog
-from pysqltools.src.SQL.query import Query, assign_parameter
+from pysqltools.src.sql.query import Query
lg = PabLog("Insert")
@@ -123,7 +122,7 @@ def insert_pandas(
task1 = progress.add_task("[red]Generating Queries...", total=1000)
task2 = progress.add_task("[green]Inserting Data...", total=iterations)
task3 = progress.add_task("[cyan]Finishing...", total=1000)
- for i in range(1000):
+ for _ in range(1000):
progress.update(task1, advance=1.0)
for query in generate_insert_query(df, table, schema, batch_size):
connection.execute(query)
diff --git a/pysqltools/src/SQL/query.py b/pysqltools/src/sql/query.py
similarity index 85%
rename from pysqltools/src/SQL/query.py
rename to pysqltools/src/sql/query.py
index 29f0339..92a4f33 100644
--- a/pysqltools/src/SQL/query.py
+++ b/pysqltools/src/sql/query.py
@@ -4,16 +4,26 @@
"""
import datetime
+import os
import re
from typing import Any, Generator, Union
+import rich
+import rich.progress
import sqlparse
from multimethod import multimethod
-from pysqltools.src.SQL.exceptions import QueryFormattingError
+from pysqltools.src.log import PabLog, progress_function
+from pysqltools.src.sql.exceptions import QueryFormattingError
+
+lg = PabLog("Query")
class QueryException(Exception):
+ """
+ Exception during Query processing.
+ """
+
def __init__(self, *args: object) -> None:
super().__init__(*args)
@@ -109,10 +119,15 @@ class Query:
`query = Query(sql = sql).format(table_param = "MyTable")`
"""
- def __init__(self, sql: str, *args, **kwargs) -> None:
+ def __init__(self, sql: str, **kwargs) -> None:
self._sql = sql.lower()
self.parsed = sqlparse.parse(sql)[0]
self.options = kwargs
+ self._parameters = None
+ self._tables = None
+ self._selects = None
+ self._windows = None
+ self._ctes = None
@property
def sql(self):
@@ -130,9 +145,7 @@ def sql(self):
reindent=True,
)
)
- return self._sql
- else:
- return self._sql
+ return self._sql
@sql.setter
def sql(self, sql: str):
@@ -249,8 +262,8 @@ def format(self, **kwargs) -> "Query":
for k, v in kwargs.items():
self.sql = self.sql.replace("{{" + k + "}}", assign_parameter(v))
self.parsed = sqlparse.parse(sql=self.sql)[0]
- except:
- raise QueryFormattingError
+ except Exception as e:
+ raise QueryFormattingError(e)
return self
@@ -284,3 +297,30 @@ def __dict__(self):
"ctes": self.get_ctes_dict(),
"parameters": list(self.parameters),
}
+
+
+@progress_function("searching queries...", color="green")
+def get_queries_from_path(path: str = None, *args, **kwargs) -> dict[str, Query]:
+    """
+    Walk `path` recursively and load every `*.sql` file into a Query.
+
+    Returns a dict keyed by file name without the extension; a name already
+    taken gets a `_2`, `_3`, ... suffix. `progress` and `task` are injected
+    as keyword arguments by the `progress_function` decorator.
+    """
+    queries = {}
+    for dirpath, _, filenames in os.walk(path):
+        lg.log.info(f"Searching {dirpath}...")
+        kwargs["progress"].advance(kwargs["task"], 1)
+        for f in filenames:
+            if not f.endswith(".sql"):
+                continue
+            lg.log.info(f"Added item '{f}'")
+            base = name = f[: -len(".sql")]
+            i = 2
+            while name in queries:  # compare stems, not raw file names
+                name, i = f"{base}_{i}", i + 1
+            with open(os.path.join(dirpath, f), "r", encoding="utf-8") as fh:
+                queries[name] = Query(fh.read())
+
+    return queries
diff --git a/pysqltools/src/SQL/table.py b/pysqltools/src/sql/table.py
similarity index 100%
rename from pysqltools/src/SQL/table.py
rename to pysqltools/src/sql/table.py
diff --git "a/pysqltools/\302\241" "b/pysqltools/\302\241"
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_connections.py b/tests/test_connections.py
index 8b45fdf..3d06263 100644
--- a/tests/test_connections.py
+++ b/tests/test_connections.py
@@ -7,7 +7,7 @@
import pandas as pd
from pysqltools.src.connection.connection import SQLConnection
-from pysqltools.src.SQL.insert import insert_pandas
+from pysqltools.src.sql.insert import insert_pandas
df = pd.DataFrame(
{
diff --git a/tests/test_queries.py b/tests/test_queries.py
index b4c346e..9c95c05 100644
--- a/tests/test_queries.py
+++ b/tests/test_queries.py
@@ -4,7 +4,7 @@
import sqlparse
from pysqltools.src import Query, generate_insert_query
-from pysqltools.src.SQL.table import Table
+from pysqltools.src.sql.table import Table
def test_ctes():