ChIr7>X$2m=3_p198kK%S(8DV#`H0}NOS`b|LXRR_IP>o0~y9q;J;qH;O14~
zD$u}V>I=9zR#UYSp*}D2#RElX@z!jV`EI)JA9xvbXi;D|O-i8~#LhEtSZ(z(p#KMn
zVn}Xv_Px0ov*?80+FmpfYqTU8pI$ffb(RS`R4kF`^S-upgp8d?aoMd*p+rar@a7Fb
zHpW?{&=aYjXtB)PQ?(XX_V%*cdm8&+8&jr!wQy1fO$JGIH&C00cno{xg592|FfgN{%w@p*DK9Nz9r0#lV
z(zM|MSC^)72fy*fXXI+c;k7a;&w`;>n)&!Hbh>B4U@h#Whz6g8C18IIDG<7@>Y%js
zma5vxplSQ2-A82;MkbGv3RtIXZhb;pj=fW3l!(TLmd=k?xbWH2xyJ~mJRMg(leQu%
z>_aoKP9^PvVZ8KvFlLT7T}pG0@Zc7)yn3oSP4X4!{{8XLfE3lEQR41T6-QO|2@{3$
zO2cLnaPc_6WARNSyPFsF=$k?(y%0%FDrfj0T#A^_qk|P;ih(L9)EytGeaV^s(JI3F
z8^83HAu+`Ti3t%^CWcnbpVQxx&z7uO8LqRdb$oxIDy0_Rq0fO2WwP8uuQKF{%ovli
zp&w~OJbk|hB~L{P&GDY$5u(QyV9tA1oSV;B>CLLr(O2i9CFbbFf6DP0Jrg2o5&7o*
z3-l$62`(Dgp17Qnz_KlIlmJl5WV8geMP}K8%o9gY$Y;ZYG!vQRntz`#;^Hrs**Hs*
z((_2QlD6j#kFAjj5a3C*vq_4N+N%=Tmq
zV^&`Ox6Logr<#ec?d!C#5(nI_n^NjI-;YYx4aD`ZfNtg?HYgzkr@=~JQHa7}ofr!G
zMWuJA2NiDv_4C}(m5$`L`8o96Tq`ygryGM*`PuM091aV|)FTYUVb#1>ak(C-&nUQ;
zxKta@r(XV9fH31LWfHp&$1;QY8qA}!sE8tQjmoKWo(*RDsCgJn__w^~>1U2@MGpy(x(NI|Qi^^u!qE@(gkC4XEo+hTQO7fYd|~S|1?(G?2X|C2@wc%TQjfo78#Dvm3vT*@0_n9eT{u|{3*!g-^4pJqLSj+Jyba9N03IRoIF_ji
z?c-0$^$D(}9o@(iA?oIz=p&-$|0j8yH@;7|ov^t246SD1&n(IsLo9*M_FrC{XZZhEXde2zrvu^sEl`knHy
zsuMa~B*ZEv4(bd_wme};Rpq>SDTWv=!VzINbD3kq3nvXIOlpYsnank7#j!!RUE>p9Q^ocmB
z?K6?IZ>z?(T6GWztU7lpeW%hP7=0qS{@dt2Ax2tfB!}vAVM>ir8G_vD)W6>RdwnaU(S&+g{5t9bG^g`Ir_ocfodBiEa=;(D}m1c^)$
zc_nVLt*6I8KAA&0fHFB)x0K;RJiyL2(D
z8tvC~VEyAB{Y`dA22Ts0#D%#Vwa)Dz>O=vmoQZ&kHf+K1HeAvKZIleydiuS(N5O-_OwEaPF*
zTq<0SKHf)fWqhZ;neW0s<%;!T%GG&_M4Ob5yk8?Mdi9q>jwxyJPthxmOs_zGC-PoL
z-1GIrPjDo6oayAF^##+xSUxJZ>?)tZgYETuH7wvZla9U{lvOFQ1aCQIYw-yd?1=u7
z<^ykY>Jv6Ssl5YxTb+a=-)zj+IgC{DkS05QNZ?C}`>}`T|>6$s$tK`YaCwe>eW{
z`;`8)dWifn5K8I9clJhkMS0aGejBP>C;&!=AnS&^Sr#GejjTfV``_vG9yq_BjYG^j
zZt(LX_|%Z=+hIpM
zj*nT}T0B2{R`g(-`YL2ZL9RTY9`b;77Za$)hYTDvWNjg+L7hZ~-uop4t^QkL5;Ufm
zg{oNB>b`THy`E36kb|FNt(2;3D_Q9A7%?7OKX*;fPd%sxHC*389yPdVD$LF;>9Ke0
zxBv4ihrNiUt-ZQjG%*X#K-RcOpPXw*;Pmb7!(*ew6xaN1%11}XQwrmsaLvEBmp;K6
z?!76_C}n9`x?h)&YiPBidt(w)e$R{AX9B(g-2h@D&H!CJWBJCWM0)y8N)onSP}<0g
zCq5bB$dUMhbQwl391$$qBCR`suY1-Yk<}^Gln3)42j%U7A<>(t^?_(in8%yMdnF6Y
zrE_ULj(nTsYin#-gT*Q_KKfclZWs%qQSuOq12JuVT_)ne!ysJ+37G03o#cS_J5*Qg
zw5|mLDU)%axqX>ki0ju+XFNFzyySVzAtMwnbtBP`nttf5W`uIvSTlGXD|~oN$!5oo
z;YK!{EBT>^snQoQUb}!#!VWO~01?>8s>adjTmRF(6d^}e#IJA=~09EDW
z1!2}q0Kf0_dit`&C!UBim)`|>-XjNAP)x%&{yZbN9LzFcXtrOW+rmg^_j>y^ub3N&
zLt6PV8Gh@)F!^=l$*|1MbxLY>LdzyVb!qB8sv<^j?;gsHGHyqQh497q(Ce`xO7Gas
zY_9c=wK{&bmzDZ)A&@ySyC*LA2~P1E2+lSG5-$=|6fblO$A%vqe6Oc=P^|_jR3?@7
zTG>&PBd?xOCnKf=NS75gv!gVfl2qiP;!#iXEd7_&4dj}2Z0-^}K|?xXIEuFW(#wj9
zlBp+1E{&;XDTynt(WgnuH+GPC+MGEf8FL@wQV30KrHqww?o}g2+~)YuR>f}(v6Lz1
z*#S3b^3{NvT!m@UXq(X+e@EIWBt`>Yw-$!&1w)&7YS>u#8`HE4WujAVL%!vF-d{H!
zG~%VMF5z3CB>o7OEm$iDy7Wu`otPQC)H>%>T@Tx4dPIIDew6J40{8oFYV#MrQV!M8
zKNZ6KCdni}{c&^BJo0iM}AA3R>aBHdMjOQ%Dl?Y~Z+Qt-AGXq9|9mJn#b65*eXG1~n0-6yzi
zsm&P_eipt=?G|3Qb+Mxx*TOe6mQN2xD6eoDS{LJ$N|3Ck%+6?~7H^a902i3JQV_2F
ztYgU8OYku>iGqpeFE=$nxohXu(cw!Kn%h)<9a!Y8h3?wFn~o)sH6Z$CeN>tflY!Ki
z{(K`Vw6ojBmAm}p_|z6Q4cyLhM$4tZZ5~(q>uFQ~(Jz457edxAy?z$x0ylQ+ZbR!q
zz$i#Zr}%@>Tc|+K4S6t_R^1{61pJW&c>#_Eqebp3>W
zlj=!ky?US`-f{{i-&4mZp@b#(7t8GZvEF8-emwr0HDyLszKx;Q;-?qf81_KS@=QQI4#Pj;RN4d8g0mY1@qsQYmpV1HxLJr@?IziRsZD-3S%
z7;H3&I_0=44Oy+R58Js+hknl-5&O(GDhSYP3+1@$0y`UE&G3`8Og6%}(gEmd>g{7(s_nG2)iC#wCjZE*NuTY)iGCo!Ol0%7c
zzg)8fD^l<{>-2p4$t(;CbyHPUTX*4S{i{FQDZkEZrl_60#o2^&0QtN~-Z6DRysL%2
z1NuPOJxS341yz(>I7jtA(%Vvyb?zfU>8n5Q?0n>08iRC8#y4M${|vV9yq7I$KlIY-
zDJ5&_Rr4032@qwVjPxHQ<{YL$`^URFBX#cXg{8?&!jD;blXk{JaUM@>_PN1D5t7nJ
z45OZJq*Jy>JDmqSS?9MQT+U$ViL5?--5Av5^J
zv?|*hXgTC=J9n1hhzRHim$lx0S3wRT_Uk!>4tBvo@QUzqe4WL=EC;l2Fh5i(|F#As
zQT~|Ypx|yw_a^JiP^*c<^%5JLe$icWYp2yO2v0|H_bo_Xat3vt@hT0xA8oV?j5(7N
z^3TwMtTIWWG@8}a_XtOP+ng7x
zYSt>8t=wmGHq`PWB+<|3I-QRvmQzwMfcCU(C~{)hzAezX{!l&b-Al
zW~^7<+`C-7^(AeNg$+%cLG8UA%dVB!i1H9;W__oD<2a1|hK_w1(`d8bE$<&X;jc8#
zF!+%4Q#L52Cl&fGV)smpJXt2qEyNO9gxI7C7OJvy{?~w%1(if{W4#bej;VZO5OyzY
zc9ff@ja!-+d-mpEoB?$m$B9Tt{JmNyn(iC(HOi;P*w2zA8r{K|ybUQpp#E_=4Q7<3
zv2Efxcz#3~X#SgK+c{s=Bjwt$2^!^4~u;
zfV-IF>!|?djxlAsgQ#lI;rseXL!44jV(y51qC9A}DK-T==JhX0hKo$}hLgfbzbU)NYVd!F8XDf&^d?5wigl4jURrlE%I&HnXm25na8o1Q}v25rEBIzwcQ None:
self.conn.commit()
except:
lg.log.warning("Connection can not be commited")
- except:
- raise ConnectionException
+ except Exception as e:
+ raise ConnectionException(e)
def fetch(self, sql: Query, dataframe: bool = False):
"""
@@ -84,7 +84,11 @@ def fetch(self, sql: Query, dataframe: bool = False):
else:
cursor = self.conn.cursor()
cursor.execute(sql.sql)
+ if hasattr(cursor, "description"):
+ columns = [i[0] for i in cursor.description]
if hasattr(cursor, "fetchall"):
+ if dataframe:
+ return pd.DataFrame(cursor.fetchall(), columns=columns)
return cursor.fetchall()
else:
rows = []
diff --git a/pysqltools/src/sql/constants.py b/pysqltools/src/sql/constants.py
index 7e5409d..807f1c4 100644
--- a/pysqltools/src/sql/constants.py
+++ b/pysqltools/src/sql/constants.py
@@ -10,4 +10,5 @@
"float64": "double",
"bool": "bool",
"datetime64": "timestamp",
+ "datetime64[ns]": "timestamp",
}
diff --git a/pysqltools/src/sql/insert.py b/pysqltools/src/sql/insert.py
index 1d3dc95..e58126c 100644
--- a/pysqltools/src/sql/insert.py
+++ b/pysqltools/src/sql/insert.py
@@ -12,59 +12,159 @@
lg = PabLog("Insert")
-def prepare_value(val: Any) -> Any:
+def prepare_value(val: Any, dialect: str) -> Any:
"""
Format value from Python types to SQL Types
"""
- if isinstance(val, bool):
+ if dialect.lower().__contains__("trino"):
+ if isinstance(val, str):
+ return f"'{val}'"
+ if isinstance(val, bool):
+ return val
+ if isinstance(val, dict):
+ val = str(val).replace("'", '"')
+ if isinstance(val, pd.Timestamp):
+ val = "TIMESTAMP '" + str(val) + "'"
+ if isinstance(val, date):
+ val = "DATE '" + str(val) + "'"
+ if isinstance(val, list):
+ val = "ARRAY " + str(val)
+ if isinstance(val, float):
+ val = "DOUBLE '" + str(val) + "'"
+ if isinstance(val, int):
+ val = "INT '" + str(val) + "'"
+ if pd.isnull(val):
+ val = "NULL"
+
+ try:
+ if (
+ "'" in val
+ and "DOUBLE" not in val
+ and "INT" not in val
+ and "TIMESTAMP" not in val
+ and "DATE" not in val
+ ):
+ val = val.replace("'", "''")
+ except TypeError:
+ lg.log.warning("Not Adding Quotes")
+
return val
- if isinstance(val, dict):
- val = str(val).replace("'", '"')
- if isinstance(val, pd.Timestamp):
- val = "TIMESTAMP '" + str(val) + "'"
- if isinstance(val, date):
- val = "DATE '" + str(val) + "'"
- if isinstance(val, list):
- val = "ARRAY " + str(val)
- if isinstance(val, float):
- val = "DOUBLE '" + str(val) + "'"
- if isinstance(val, int):
- val = "INT '" + str(val) + "'"
- if pd.isnull(val):
- val = "NULL"
-
- try:
- if (
- "'" in val
- and "DOUBLE" not in val
- and "INT" not in val
- and "TIMESTAMP" not in val
- and "DATE" not in val
- ):
- val = val.replace("'", "''")
- except TypeError:
- lg.log.warning("Not Adding Quotes")
+ if dialect.lower().__contains__("mysql"):
+ if isinstance(val, str):
+ return f"'{val}'"
+ if isinstance(val, bool):
+ return bool(val)
+ if isinstance(val, dict):
+ val = str(val).replace("'", '"')
+ if isinstance(val, pd.Timestamp):
+ val = f"'{val}'"
+ if isinstance(val, date):
+ val = f"'{val}'"
+ if isinstance(val, list):
+ val = f"'{str(val)}'"
+ if isinstance(val, float):
+ val = str(val)
+ if isinstance(val, int):
+ val = str(val)
+ if pd.isnull(val):
+ val = "NULL"
+
+ return val
+
+ if dialect.lower().__contains__("sqlite"):
+ if isinstance(val, str):
+ return f"'{val}'"
+ if isinstance(val, bool):
+ return bool(val)
+ if isinstance(val, dict):
+ val = str(val).replace("'", '"')
+ if isinstance(val, pd.Timestamp):
+ val = f"'{val}'"
+ if isinstance(val, date):
+ val = f"'{val}'"
+ if isinstance(val, list):
+ val = f"'{str(val)}'"
+ if isinstance(val, float):
+ val = str(val)
+ if isinstance(val, int):
+ val = str(val)
+ if pd.isnull(val):
+ val = "NULL"
- return val
+ return val
+ if dialect.lower().__contains__("ibm"):
+ if isinstance(val, str):
+ return f"'{val}'"
+ if isinstance(val, bool):
+ return bool(val)
+ if isinstance(val, dict):
+ val = str(val).replace("'", '"')
+ if isinstance(val, pd.Timestamp):
+ val = f"TIMESTAMP '{val}'"
+ if isinstance(val, date):
+ val = f"DATE '{val}'"
+ if isinstance(val, list):
+ val = f"'{str(val)}'"
+ if isinstance(val, float):
+ val = str(val)
+ if isinstance(val, int):
+ val = str(val)
+ if pd.isnull(val):
+ val = "NULL"
-def join_values(data: list[Any]) -> str:
+ return val
+
+ if dialect.lower().__contains__("sqlserver"):
+ if isinstance(val, str):
+ return f"'{val}'"
+ if isinstance(val, bool):
+ return bool(val)
+ if isinstance(val, dict):
+ val = str(val).replace("'", '"')
+ if isinstance(val, pd.Timestamp):
+ val = f"'{val}'"
+ if isinstance(val, date):
+ val = f"'{val}'"
+ if isinstance(val, list):
+ val = f"'{str(val)}'"
+ if isinstance(val, float):
+ val = str(val)
+ if isinstance(val, int):
+ val = str(val)
+ if pd.isnull(val):
+ val = "NULL"
+
+ return val
+
+ if dialect.lower().__contains__("mariadb"):
+ if isinstance(val, str):
+ return f"'{val}'"
+ if isinstance(val, bool):
+ return bool(val)
+ if isinstance(val, dict):
+ val = str(val).replace("'", '"')
+ if isinstance(val, pd.Timestamp):
+ val = f"'{val}'"
+ if isinstance(val, date):
+ val = f"'{val}'"
+ if isinstance(val, list):
+ val = f"'{str(val)}'"
+ if isinstance(val, float):
+ val = str(val)
+ if isinstance(val, int):
+ val = str(val)
+ if pd.isnull(val):
+ val = "NULL"
+ return val
+
+
+def join_values(data: list[Any], dialect: str) -> str:
"""
Create a String for the VALUES () SQL Syntax
"""
clean_list = []
for val in data:
- if isinstance(val, bool):
- val = str(val)
- if (
- isinstance(val, str)
- and "DOUBLE" not in val
- and "INT" not in val
- and "TIMESTAMP" not in val
- and "DATE" not in val
- and "ARRAY" not in val
- ) and val.lower() not in ["true", "false"]:
- val = "'" + val + "'"
try:
if "NULL" in val:
val = "NULL"
@@ -76,18 +176,22 @@ def join_values(data: list[Any]) -> str:
return "(" + str_data + ")"
-def pandas_to_sql(df: pd.DataFrame) -> Generator[str, None, None]:
+def pandas_to_sql(df: pd.DataFrame, dialect: str) -> Generator[str, None, None]:
"""
Generator to get one row insert statement
"""
for row in df.values:
- data_list = [prepare_value(x) for x in row]
- data_string = join_values(data_list)
+ data_list = [prepare_value(x, dialect=dialect) for x in row]
+ data_string = join_values(data_list, dialect=dialect)
yield data_string
def generate_insert_query(
- df: pd.DataFrame, table: str = None, schema: str = None, batch_size: int = 5000
+ df: pd.DataFrame,
+ table: str = None,
+ schema: str = None,
+ batch_size: int = 5000,
+ dialect: str = "trino",
) -> Generator[Query, None, None]:
if df.empty:
raise TypeError("DataFrame can not be empty")
@@ -96,7 +200,7 @@ def generate_insert_query(
percentage = round(100 * previous_iter / len(df), 2)
lg.log.info("Generating Insert Queries... %s", percentage)
batch = df.iloc[previous_iter : previous_iter + batch_size]
- data_points = list(pandas_to_sql(batch))
+ data_points = list(pandas_to_sql(batch, dialect))
data_points_string = ",".join(data_points)
if schema and table:
table = f"{schema}.{table}"
@@ -111,21 +215,29 @@ def generate_insert_query(
def insert_pandas(
df: pd.DataFrame,
connection: SQLConnection,
- table: str,
- schema: str,
batch_size: int,
+ table: str,
+ schema: str = "",
+ dialect: str = "trino",
):
if not table and schema:
raise TypeError("Table and Schema need to be provided")
with Progress() as progress:
iterations = len(df) / batch_size
- task1 = progress.add_task("[red]Generating Queries...", total=1000)
- task2 = progress.add_task("[green]Inserting Data...", total=iterations)
- task3 = progress.add_task("[cyan]Finishing...", total=1000)
+ task1 = progress.add_task("[red] Generating Queries...", total=1000)
+ task2 = progress.add_task("[green] Inserting Data...", total=iterations)
+ task3 = progress.add_task("[cyan] Finishing...", total=1000)
for _ in range(1000):
progress.update(task1, advance=1.0)
- for query in generate_insert_query(df, table, schema, batch_size):
- connection.execute(query)
+ for query in generate_insert_query(
+ df, table, schema, batch_size, dialect=dialect
+ ):
+ try:
+ connection.execute(query)
+ except Exception as e:
+ lg.log.warning("Query Execution Failed")
+ lg.log.error(e)
+ print(query.sql)
progress.update(task2, advance=1)
for i in range(1000):
progress.update(task3, advance=1.0)
diff --git a/pysqltools/src/sql/table.py b/pysqltools/src/sql/table.py
index 21c1386..0c7cefc 100644
--- a/pysqltools/src/sql/table.py
+++ b/pysqltools/src/sql/table.py
@@ -7,6 +7,8 @@
import pandas as pd
import sqlparse
+from pysqltools.src.sql.query import Query
+
from .constants import TYPE_MAPPING
from .insert import insert_pandas
@@ -22,7 +24,11 @@ def __init__(self, table: str, schema: Union[str, None] = None) -> None:
self.table = f"{schema}.{table}"
def create_from_df(
- self, df: pd.DataFrame, insert_data: bool = False, **insert_kwargs: Any
+ self,
+ df: pd.DataFrame,
+ execute: bool = False,
+ insert_data: bool = False,
+ **insert_kwargs: Any,
) -> str:
"""
Get the SQL statement to create a SQL table based on a Pandas DataFrame. If the insert_data argument is set to True,
@@ -41,12 +47,14 @@ def create_from_df(
),
)
)
- sql = f"CREATE TABLE {self.table} ( "
+ sql = f"CREATE TABLE IF NOT EXISTS {self.table} ( "
for k, v in columns.items():
sql += f"{k} {v}, "
sql = sql[:-2] + " )"
if not insert_data:
return sqlparse.format(sql, encoding="utf-8")
+ if execute:
+ insert_kwargs["connection"].execute(Query(sql))
if "batch_size" in insert_kwargs:
batch_size = insert_kwargs["batch_size"]
else:
@@ -57,8 +65,10 @@ def create_from_df(
connection=insert_kwargs["connection"],
table=self.table,
batch_size=batch_size,
+ dialect=insert_kwargs["dialect"],
)
- except TypeError:
+ except TypeError as e:
raise TypeError(
- "Please include the insert arguments into the create_table_from_df method"
+ "Please include the insert arguments into the create_table_from_df method",
+ e,
)
diff --git "a/pysqltools/\302\241" "b/pysqltools/\302\241"
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/test_connections.py b/tests/test_connections.py
index 3d06263..cf21670 100644
--- a/tests/test_connections.py
+++ b/tests/test_connections.py
@@ -1,3 +1,4 @@
+import sqlite3
from datetime import datetime
from unittest.mock import MagicMock, patch
@@ -5,9 +6,12 @@
import mysql.connector
import mysql.connector.cursor
import pandas as pd
+import pytest
from pysqltools.src.connection.connection import SQLConnection
from pysqltools.src.sql.insert import insert_pandas
+from pysqltools.src.sql.query import Query
+from pysqltools.src.sql.table import Table
df = pd.DataFrame(
{
@@ -31,8 +35,56 @@ def test_insert_with_conn(mock_connect, mock_cursor, mock_commit, mock_execute):
conn = SQLConnection(conn=mock_conn)
conn.conn.cursor = MagicMock(return_value=mysql.connector.cursor.MySQLCursor())
try:
- insert_pandas(df, conn, "myTable", "MySchema", batch_size=1)
+ insert_pandas(
+ df=df, connection=conn, table="myTable", schema="MySchema", batch_size=1
+ )
result = True
except:
result = False
assert result
+
+
+@pytest.mark.skip(reason="This test is skipped unconditionally")
+def test_sqlite():
+ conn = sqlite3.connect("tests/test_db.sqlite3")
+ connection = SQLConnection(conn=conn)
+
+ df = pd.DataFrame(
+ {
+ "id": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ "amount": [
+ 111342,
+ 463463,
+ 6357,
+ 435765,
+ 757456,
+ 84678,
+ 34,
+ 7547,
+ 74567,
+ 76,
+ ],
+ "dt": [datetime.today() for _ in range(10)],
+ "strings": ["a", "b", "a", "a", "a", "a", "b", "a", "a", "a"],
+ }
+ )
+ table = Table("test_table")
+ table.create_from_df(
+ df=df,
+ execute=True,
+ insert_data=True,
+ connection=connection,
+ batch_size=1,
+ dialect="sqlite",
+ )
+
+ query = Query("select * from test_table")
+ data = connection.fetch(query, dataframe=True)
+ os.remove("tests/test_db.sqlite3")
+ data["dt"] = pd.to_datetime(data["dt"])
+ data = data.iloc[:10]
+ df["dt"] = pd.to_datetime(df["dt"])
+
+ assert [df[i].tolist() for i in df.columns] == [
+ data[i].tolist() for i in df.columns
+ ]
diff --git a/tests/test_queries.py b/tests/test_queries.py
index 9c95c05..af00b0c 100644
--- a/tests/test_queries.py
+++ b/tests/test_queries.py
@@ -83,9 +83,7 @@ def test_cte_replacement():
def test_create_table_string():
- expected = (
- "CREATE TABLE myTable ( col1 int, col11 double, col2 bool, col3 varchar )"
- )
+ expected = "CREATE TABLE IF NOT EXISTS myTable ( col1 int, col11 double, col2 bool, col3 varchar )"
with open("tests/queries/test_cte.sql", "r", encoding="utf-8") as f:
sql = f.read()