11import logging
2+ from typing import Any , Dict , Optional , Union
23
34import pandas as pd
45import psycopg2
56import psycopg2 .extras
67
78from .retry_decorator import retry
9+ from .utils .dataframe_functions import replace_nan_with_none_in_dataframe
810
911
1012class PostgresEngine :
11-
12- def __init__ (self , databaseName : str , user : str , password : str , host : str = 'localhost' , port : int = 5432 ):
13+ def __init__ (self , databaseName : str , user : str , password : str , host : str = "localhost" , port : int = 5432 ):
1314 """
1415 Class for accessing Postgres databases more easily.
1516
@@ -27,72 +28,79 @@ def __init__(self, databaseName: str, user: str, password: str, host: str = 'loc
2728 self .connection = None
2829 self .cursor = None
2930
30- def _get_connection (self ):
31+ def _get_connection (self ) -> None :
3132 try :
32- self .connection = psycopg2 .connect (user = self .user , password = self .password , host = self .host , port = self .port , database = self .databaseName )
33+ self .connection = psycopg2 .connect (
34+ user = self .user , password = self .password , host = self .host , port = self .port , database = self .databaseName
35+ )
3336 except Exception as ex :
34- logging .exception (f' Error connecting to PostgreSQL { ex } ' )
37+ logging .exception (f" Error connecting to PostgreSQL { ex } " )
3538 raise ex
3639
37- def _get_cursor (self , isInsertionQuery : bool ):
40+ def _get_cursor (self , isInsertionQuery : bool ) -> None :
3841 if isInsertionQuery :
3942 self .cursor = self .connection .cursor ()
4043 else :
4144 self .cursor = self .connection .cursor (cursor_factory = psycopg2 .extras .RealDictCursor )
4245
43- def _close_connection (self ):
46+ def _close_connection (self ) -> None :
4447 self .connection .close ()
4548
46- def _close_cursor (self ):
49+ def _close_cursor (self ) -> None :
4750 self .cursor .close ()
4851
49- def close (self ):
52+ def close (self ) -> None :
5053 if self .connection is not None :
5154 self ._close_connection ()
5255 if self .cursor is not None :
5356 self ._close_cursor ()
5457
55- def create_table (self , schema : str ):
58+ def create_table (self , schema : str ) -> None :
5659 self ._get_connection ()
5760 self ._get_cursor (isInsertionQuery = True )
5861 self .cursor .execute (schema )
5962 try :
6063 self .connection .commit ()
6164 except Exception as ex :
62- logging .exception (f' error: { ex } \n schemaQuery: { schema } ' )
65+ logging .exception (f" error: { ex } \n schemaQuery: { schema } " )
6366 raise ex
6467 finally :
6568 self .close ()
6669
67- def create_index (self , tableName : str , column : str ):
70+ def create_index (self , tableName : str , column : str ) -> None :
6871 self ._get_connection ()
6972 self ._get_cursor (isInsertionQuery = True )
70- indexQuery = f' CREATE INDEX IF NOT EXISTS { tableName } _{ column } ON { tableName } ({ column } );'
73+ indexQuery = f" CREATE INDEX IF NOT EXISTS { tableName } _{ column } ON { tableName } ({ column } );"
7174 self .cursor .execute (indexQuery )
7275 try :
7376 self .connection .commit ()
7477 except Exception as ex :
75- logging .exception (f' error: { ex } \n indexQuery: { indexQuery } ' )
78+ logging .exception (f" error: { ex } \n indexQuery: { indexQuery } " )
7679 raise ex
7780 finally :
7881 self .close ()
7982
8083 @retry (numRetries = 5 , retryDelaySeconds = 3 , backoffScalingFactor = 2 )
81- def run_select_query (self , query : str , parameters : dict = None ):
84+ def run_select_query_with_retry (self , query : str , parameters : Optional [Dict [str , Any ]] = None ) -> pd .DataFrame :
85+ return self .run_select_query (query = query , parameters = parameters )
86+
87+ def run_select_query (self , query : str , parameters : Optional [Dict [str , Any ]] = None ) -> pd .DataFrame :
8288 self ._get_connection ()
8389 self ._get_cursor (isInsertionQuery = False )
8490 self .cursor .execute (query , parameters )
8591 outputs = self .cursor .fetchall ()
8692 self .close ()
8793 outputDataframe = pd .DataFrame (outputs )
88- return outputDataframe . where ( outputDataframe . notnull (), None ). dropna ( axis = 0 , how = 'all' )
94+ return replace_nan_with_none_in_dataframe ( dataframe = outputDataframe )
8995
9096 @retry (numRetries = 5 , retryDelaySeconds = 3 , backoffScalingFactor = 2 )
91- def run_update_query (self , query : str , parameters : dict = None , returnId : bool = True ):
97+ def run_update_query (
98+ self , query : str , parameters : Optional [Dict [str , Any ]] = None , returnId : bool = True
99+ ) -> Union [None , int ]:
92100 self ._get_connection ()
93101 self ._get_cursor (isInsertionQuery = True )
94102 if returnId :
95- query = f' { query } \n RETURNING id'
103+ query = f" { query } \n RETURNING id"
96104 self .cursor .execute (query , parameters )
97105 if returnId :
98106 insertedId = self .cursor .fetchone ()[0 ]
@@ -101,7 +109,7 @@ def run_update_query(self, query: str, parameters: dict = None, returnId: bool =
101109 try :
102110 self .connection .commit ()
103111 except Exception as ex :
104- logging .exception (f' error: { ex } \n query: { query } \n parameters: { parameters } ' )
112+ logging .exception (f" error: { ex } \n query: { query } \n parameters: { parameters } " )
105113 raise ex
106114 finally :
107115 self .close ()
0 commit comments