From 53be34ab67904384e85a7a113b757e2fbf9a8ad1 Mon Sep 17 00:00:00 2001 From: Pablo Minue Date: Sun, 7 Jul 2024 15:44:27 +0200 Subject: [PATCH] Added get_queries_from_path function, applied some linting to the code --- pysqltools/__init__.py | 6 +-- pysqltools/src/SQL/__init__.py | 5 --- pysqltools/src/__init__.py | 6 +-- pysqltools/src/connection/connection.py | 17 ++++++- pysqltools/src/log/__init__.py | 2 +- pysqltools/src/log/log.py | 7 +-- pysqltools/src/sql/__init__.py | 8 ++++ pysqltools/src/{SQL => sql}/constants.py | 0 pysqltools/src/{SQL => sql}/exceptions.py | 0 pysqltools/src/{SQL => sql}/insert.py | 5 +-- pysqltools/src/{SQL => sql}/query.py | 54 ++++++++++++++++++++--- pysqltools/src/{SQL => sql}/table.py | 0 "pysqltools/\302\241" | 0 tests/test_connections.py | 2 +- tests/test_queries.py | 2 +- 15 files changed, 85 insertions(+), 29 deletions(-) delete mode 100644 pysqltools/src/SQL/__init__.py create mode 100644 pysqltools/src/sql/__init__.py rename pysqltools/src/{SQL => sql}/constants.py (100%) rename pysqltools/src/{SQL => sql}/exceptions.py (100%) rename pysqltools/src/{SQL => sql}/insert.py (97%) rename pysqltools/src/{SQL => sql}/query.py (85%) rename pysqltools/src/{SQL => sql}/table.py (100%) create mode 100644 "pysqltools/\302\241" diff --git a/pysqltools/__init__.py b/pysqltools/__init__.py index d080443..8b33600 100644 --- a/pysqltools/__init__.py +++ b/pysqltools/__init__.py @@ -1,7 +1,7 @@ from .src.connection import SQLConnection -from .src.SQL.insert import generate_insert_query, insert_pandas -from .src.SQL.query import Query, SQLString -from .src.SQL.table import Table +from .src.sql.insert import generate_insert_query, insert_pandas +from .src.sql.query import Query, SQLString, get_queries_from_path +from .src.sql.table import Table def format_sql(sql: str, **kwargs) -> str: diff --git a/pysqltools/src/SQL/__init__.py b/pysqltools/src/SQL/__init__.py deleted file mode 100644 index 8f69434..0000000 --- a/pysqltools/src/SQL/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -""" -Queries Package. Contains everything SQL-Text query related -""" - -from pysqltools.src.SQL.query import Query, SQLString diff --git a/pysqltools/src/__init__.py b/pysqltools/src/__init__.py index 9cad2bc..74b2fb7 100644 --- a/pysqltools/src/__init__.py +++ b/pysqltools/src/__init__.py @@ -2,6 +2,6 @@ Source code for the pysqltools package """ -from pysqltools.src.SQL.insert import generate_insert_query, insert_pandas -from pysqltools.src.SQL.query import Query, SQLString -from pysqltools.src.SQL.table import Table +from pysqltools.src.sql.insert import generate_insert_query, insert_pandas +from pysqltools.src.sql.query import Query, SQLString, get_queries_from_path +from pysqltools.src.sql.table import Table diff --git a/pysqltools/src/connection/connection.py b/pysqltools/src/connection/connection.py index 3cfb2e4..8eae077 100644 --- a/pysqltools/src/connection/connection.py +++ b/pysqltools/src/connection/connection.py @@ -4,6 +4,7 @@ import ibm_db import mysql import mysql.connector +import pandas as pd import pymssql import pymysql @@ -13,12 +14,16 @@ from pysqltools.src.connection.exceptions import ConnectionException from pysqltools.src.log import PabLog -from pysqltools.src.SQL.query import Query +from pysqltools.src.sql.query import Query lg = PabLog("Connections") class SQLConnection: + """ + Unified connection class for different SQL Dialects. + """ + conn = None def __init__( @@ -38,6 +43,9 @@ def __init__( self.conn = conn def execute(self, sql: Query) -> None: + """ + Execute a SQL Statement that returns no value + """ try: if isinstance(self.conn, ibm_db.IBM_DBConnection): ibm_db.exec_immediate(self.conn, sql.sql) @@ -55,7 +63,10 @@ def execute(self, sql: Query) -> None: except: raise ConnectionException - def fetch(self, sql: Query): + def fetch(self, sql: Query, dataframe: bool = False): + """ + Execute a SQL Query object and get the + """ try: if isinstance(self.conn, ibm_db.IBM_DBConnection): stmt = ibm_db.exec_immediate(self.conn, sql.sql) @@ -81,6 +92,8 @@ def fetch(self, sql: Query): while row: rows.append(row) row = cursor.fetchone() + if dataframe: + return pd.DataFrame(rows) return rows except Exception as e: lg.log.error(f"Fetch failed: {e}") diff --git a/pysqltools/src/log/__init__.py b/pysqltools/src/log/__init__.py index 3b69233..7413192 100644 --- a/pysqltools/src/log/__init__.py +++ b/pysqltools/src/log/__init__.py @@ -1 +1 @@ -from .log import PabLog +from .log import PabLog, progress_function diff --git a/pysqltools/src/log/log.py b/pysqltools/src/log/log.py index 0b252ea..f8bd18f 100644 --- a/pysqltools/src/log/log.py +++ b/pysqltools/src/log/log.py @@ -4,7 +4,6 @@ from typing import Any, Callable import pandas as pd -import rich from rich.console import Console from rich.logging import RichHandler from rich.markdown import Markdown @@ -44,6 +43,7 @@ def __init__( logging.basicConfig(format=__format, handlers=__handlers, level=logging.DEBUG) self.log = logging.getLogger(log_name) + self.log.setLevel(logging.INFO) self.console = Console() def add_table(self, df: pd.DataFrame, title: str = "", max_rows: int = 10) -> None: @@ -88,8 +88,9 @@ def decorator(fun: Callable[..., Any]): def inner(*args, **kwargs): """Wrapped""" with Progress() as progress: - task = progress.add_task("f[{color}] {task_name}...", total=total) - result = fun(progress, task, *args, **kwargs) + task = progress.add_task(f"[{color}] {task_name}...", total=total) + prog = {"progress": progress, "task": task} + result = fun(*args, **kwargs, **prog) progress.update(task, advance=total) return result diff --git a/pysqltools/src/sql/__init__.py b/pysqltools/src/sql/__init__.py new file mode 100644 index 0000000..0fd5c84 --- /dev/null +++ b/pysqltools/src/sql/__init__.py @@ -0,0 +1,8 @@ +""" +Queries Package. Contains everything SQL-Text query related +""" + +from pysqltools.src.connection import SQLConnection +from pysqltools.src.sql.insert import generate_insert_query, insert_pandas +from pysqltools.src.sql.query import Query, SQLString, get_queries_from_path +from pysqltools.src.sql.table import Table diff --git a/pysqltools/src/SQL/constants.py b/pysqltools/src/sql/constants.py similarity index 100% rename from pysqltools/src/SQL/constants.py rename to pysqltools/src/sql/constants.py diff --git a/pysqltools/src/SQL/exceptions.py b/pysqltools/src/sql/exceptions.py similarity index 100% rename from pysqltools/src/SQL/exceptions.py rename to pysqltools/src/sql/exceptions.py diff --git a/pysqltools/src/SQL/insert.py b/pysqltools/src/sql/insert.py similarity index 97% rename from pysqltools/src/SQL/insert.py rename to pysqltools/src/sql/insert.py index 2e4b330..1d3dc95 100644 --- a/pysqltools/src/SQL/insert.py +++ b/pysqltools/src/sql/insert.py @@ -1,4 +1,3 @@ -import time from datetime import date from typing import Any, Generator @@ -8,7 +7,7 @@ from pysqltools.src.connection import SQLConnection from pysqltools.src.log import PabLog -from pysqltools.src.SQL.query import Query, assign_parameter +from pysqltools.src.sql.query import Query lg = PabLog("Insert") @@ -123,7 +122,7 @@ def insert_pandas( task1 = progress.add_task("[red]Generating Queries...", total=1000) task2 = progress.add_task("[green]Inserting Data...", total=iterations) task3 = progress.add_task("[cyan]Finishing...", total=1000) - for i in range(1000): + for _ in range(1000): progress.update(task1, advance=1.0) for query in generate_insert_query(df, table, schema, batch_size): connection.execute(query) diff --git a/pysqltools/src/SQL/query.py b/pysqltools/src/sql/query.py similarity index 85% rename from pysqltools/src/SQL/query.py rename to pysqltools/src/sql/query.py index 29f0339..92a4f33 100644 --- a/pysqltools/src/SQL/query.py +++ b/pysqltools/src/sql/query.py @@ -4,16 +4,26 @@ """ import datetime +import os import re from typing import Any, Generator, Union +import rich +import rich.progress import sqlparse from multimethod import multimethod -from pysqltools.src.SQL.exceptions import QueryFormattingError +from pysqltools.src.log import PabLog, progress_function +from pysqltools.src.sql.exceptions import QueryFormattingError + +lg = PabLog("Query") class QueryException(Exception): + """ + Exception during Query processing. + """ + def __init__(self, *args: object) -> None: super().__init__(*args) @@ -109,10 +119,15 @@ class Query: `query = Query(sql = sql).format(table_param = "MyTable")` """ - def __init__(self, sql: str, *args, **kwargs) -> None: + def __init__(self, sql: str, **kwargs) -> None: self._sql = sql.lower() self.parsed = sqlparse.parse(sql)[0] self.options = kwargs + self._parameters = None + self._tables = None + self._selects = None + self._windows = None + self._ctes = None @property def sql(self): @@ -130,9 +145,7 @@ def sql(self): reindent=True, ) ) - return self._sql - else: - return self._sql + return self._sql @sql.setter def sql(self, sql: str): @@ -249,8 +262,8 @@ def format(self, **kwargs) -> "Query": for k, v in kwargs.items(): self.sql = self.sql.replace("{{" + k + "}}", assign_parameter(v)) self.parsed = sqlparse.parse(sql=self.sql)[0] - except: - raise QueryFormattingError + except Exception as e: + raise QueryFormattingError(e) return self @@ -284,3 +297,30 @@ def __dict__(self): "ctes": self.get_ctes_dict(), "parameters": list(self.parameters), } + + +@progress_function("searching queries...", color="green") +def get_queries_from_path(path: str = None, *args, **kwargs) -> list[Query]: + # lg.add_md("## Scanning Directory for SQL Queries") + queries = {} + for dirpath, dirname, filename in os.walk(path): + lg.log.info(f"Searching {dirpath}...") + kwargs["progress"].advance(kwargs["task"], 1) + i = 2 + for f in filename: + if f.__contains__(".sql"): + lg.log.info(f"Added item '{f}'") + if f in queries: + name = f.replace(".sql", "") + f"_{i}" + i += 1 + else: + name = f.replace(".sql", "") + queries.update( + { + name: Query( + open(os.path.join(dirpath, f), "r", encoding="utf-8").read() + ) + } + ) + + return queries diff --git a/pysqltools/src/SQL/table.py b/pysqltools/src/sql/table.py similarity index 100% rename from pysqltools/src/SQL/table.py rename to pysqltools/src/sql/table.py diff --git "a/pysqltools/\302\241" "b/pysqltools/\302\241" new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_connections.py b/tests/test_connections.py index 8b45fdf..3d06263 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -7,7 +7,7 @@ import pandas as pd from pysqltools.src.connection.connection import SQLConnection -from pysqltools.src.SQL.insert import insert_pandas +from pysqltools.src.sql.insert import insert_pandas df = pd.DataFrame( { diff --git a/tests/test_queries.py b/tests/test_queries.py index b4c346e..9c95c05 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -4,7 +4,7 @@ import sqlparse from pysqltools.src import Query, generate_insert_query -from pysqltools.src.SQL.table import Table +from pysqltools.src.sql.table import Table def test_ctes():