Skip to content

Commit 05b2d77

Browse files
committed
Simplified DBClient.__init__().
1 parent e4aa2b8 commit 05b2d77

File tree

2 files changed

+86
-70
lines changed

2 files changed

+86
-70
lines changed

OracleClient.py

Lines changed: 81 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
88
AUTHOR: David J. Lambert
99
10-
VERSION: 0.4.4
10+
VERSION: 0.4.6
1111
1212
DATE: Mar 12, 2020
1313
@@ -16,24 +16,36 @@
1616
test it, and document it. I also demonstrate I know relational databases.
1717
1818
DESCRIPTION:
19-
Class DBInstance contains all the information needed to log into an Oracle
20-
instance, plus it contains the connection handle to that database.
19+
Class DBInstance encapsulates all the information needed to log into an Oracle
20+
instance, plus it encapsulates the connection handle to that database.
21+
Its externally useful methods are:
22+
1) print_all_connection_parameters: prints all the connection parameters.
23+
2) close_connection: closes the connect to the database.
24+
3) print_connection_status: whether or not DBInstance is connected to the db.
2125
2226
Class DBClient executes SQL with bind variables, and then prints the results.
2327
Its externally useful methods are:
24-
1) set_sql: gets the text of SQL to run, including bind variables.
25-
2) get_sql: reads text of SQL to run from a prompt.
26-
3) get_bind_vars: gets the bind variable names and values from a prompt.
28+
1) set_sql: gets SQL and bind variables as arguments.
29+
2) set_sql_at_prompt: reads SQL from a prompt.
30+
3) set_bind_vars_at_prompt: reads bind variable names & values from a prompt.
2731
4) run_sql: executes SQL, which was read as a text variable (with set_sql)
28-
or entered at a prompt (by get_sql and get_bind_vars).
32+
or entered at a prompt (by set_sql_at_prompt and set_bind_vars_at_prompt).
2933
5) database_table_schema: lists all the tables owned by the current login,
3034
all the columns in those tables, and all indexes on those tables.
3135
6) database_view_schema: lists all the views owned by the current login,
3236
all the columns in those views, and the SQL for the view.
3337
3438
Class OutputWriter handles all query output to file or to standard output.
39+
Its externally useful methods are:
40+
1) get_align_col: whether or not to align columns in output.
41+
2) get_col_sep: get the character(s) to separate columns with.
42+
3) get_out_file_name: get location to write output to (file or standard out).
43+
4) write_rows: write output to location chosen in get_out_file_name.
44+
5) olose_output_file: if writing to output file, close it.
3545
36-
Stand-alone Method run_sql_cmdline runs sqlplus as a subprocess.
46+
Stand-alone function run_sql_cmdline runs sqlplus as a subprocess.
47+
Stand-alone function ask_for_password(username) prompts for the password for
48+
the username provided as an argument.
3749
3850
The code has been tested with CRUD statements (Create, Read, Update, Delete).
3951
There is nothing to prevent the end-user from entering other SQL, such as
@@ -111,7 +123,7 @@ class DBInstance(object):
111123
hostname (str): the hostname of this database.
112124
port_num (int): the port this database listens on.
113125
instance (str): the name of this database instance.
114-
db_library: reference to the library that was imported for this db_type.
126+
db_library (object): library object that was imported for this db_type.
115127
db_library_name (str): name of the library imported for this db_type.
116128
connection: the handle to this database. I set connection = None when
117129
connection closed, this is not default behavior.
@@ -138,58 +150,53 @@ def __init__(self, db_type: str, username: str, password: str,
138150
sqlite = 'sqlite'
139151
db_types = [oracle, mysql, sql_server, postgresql, access, sqlite]
140152

141-
uses_conn_str = set(db_types) - {mysql}
153+
uses_connection_string = set(db_types) - {mysql}
142154

143155
# Check if db_type valid.
144156
if db_type not in db_types:
145157
print('Invalid database type "{}".'.format(db_type))
158+
exit(1)
146159

147-
# Libraries for supported database types.
148-
map_type_to_lib = {oracle: 'cx_Oracle',
149-
mysql: 'pymysql',
150-
sql_server: 'pyodbc',
151-
postgresql: 'psycopg2',
152-
access: 'pyodbc',
153-
sqlite: 'sqlite3'}
160+
# Library names for supported database types.
161+
db_libraries = {oracle: 'cx_Oracle', mysql: 'pymysql',
162+
sql_server: 'pyodbc', postgresql: 'psycopg2',
163+
access: 'pyodbc', sqlite: 'sqlite3'}
154164

155165
# Import appropriate library.
156-
db_library = __import__(map_type_to_lib[db_type])
166+
db_library = __import__(db_libraries[db_type])
157167

158-
# Database connection string.
159-
conn_str = ''
168+
# Form database connection string.
160169
z = ''
161170
if db_type == mysql:
162-
pass
171+
z = ''
163172
elif db_type == sql_server:
164-
conn_str = r'DRIVER={SQL Server};'
165-
z = r'UID={};PWD={};SERVER={};PORT={};DATABASE={}'
173+
z = ('DRIVER={{SQL Server}};UID={};PWD={};SERVER={};PORT={};'
174+
'DATABASE={}')
166175
elif db_type == oracle:
167176
z = '{}/{}@{}:{}/{}'
168177
elif db_type == postgresql:
169178
z = "user='{}' password='{}' host='{}' port='{}' dbname='{}'"
170179
elif db_type == access:
171-
conn_str = r'DRIVER={Microsoft Access Driver (*.mdb, *.accdb)};'
172-
z = r'DBQ={};'
180+
z = 'DRIVER={{Microsoft Access Driver (*.mdb, *.accdb)}};DBQ={};'
173181
elif db_type == sqlite:
174182
z = '{}'
175183
else:
176184
print('Unknown db type {}, aborting.'.format(db_type))
185+
exit(1)
177186

178187
if db_type in {sql_server, oracle, postgresql}:
179-
conn_str += z.format(username, password, hostname, port_num,
180-
instance)
188+
z = z.format(username, password, hostname, port_num, instance)
181189
elif db_type in {sqlite, access}:
182-
conn_str += z.format(instance)
190+
z = z.format(instance)
183191

184192
# Connect to database instance.
185193
self.connection = None
186194
try:
187-
if db_type in uses_conn_str:
188-
self.connection = db_library.connect(conn_str)
195+
if db_type in uses_connection_string:
196+
self.connection = db_library.connect(z)
189197
else:
190-
self.connection = db_library.connect(username, password,
191-
hostname, port_num,
192-
instance)
198+
self.connection = db_library.\
199+
connect(username, password, hostname, port_num, instance)
193200
print('Successfully connected to database.')
194201
except db_library.Error:
195202
print_stacktrace()
@@ -204,7 +211,7 @@ def __init__(self, db_type: str, username: str, password: str,
204211
self.port_num: int = port_num
205212
self.instance: str = instance
206213
self.db_library = db_library
207-
self.db_library_name = map_type_to_lib[db_type]
214+
self.db_library_name = db_libraries[db_type]
208215

209216
return
210217

@@ -266,18 +273,28 @@ def get_db_library_name(self) -> str:
266273
267274
Parameters:
268275
Returns:
269-
type (str): database software type.
276+
db_library_name (str): database library name.
270277
"""
271278
return self.db_library_name
272279

280+
def get_db_type(self) -> str:
281+
""" Method to return the database type.
282+
283+
Parameters:
284+
Returns:
285+
db_type (str): database software type.
286+
"""
287+
return self.db_type
288+
273289
def print_all_connection_parameters(self) -> None:
274290
""" Method that executes all print methods of this class.
275291
276292
Parameters:
277293
Returns:
278294
"""
279295
print('The database type is "{}".'.format(self.db_type))
280-
print('The database software version is "{}".'.format(self.connection.version))
296+
print('The database software version is "{}".'.format(
297+
self.connection.version))
281298
print('The database username is "{}".'.format(self.username))
282299
print('The database hostname is "{}".'.format(self.hostname))
283300
print('The database port number is {}.'.format(self.port_num))
@@ -336,7 +353,7 @@ def __init__(self, out_file_name: str = '', align_col: bool = True,
336353
self.col_sep: str = col_sep
337354
return
338355

339-
def clean_up(self):
356+
def close_output_file(self):
340357
""" Close output file, if it exists.
341358
342359
Parameters:
@@ -356,7 +373,9 @@ def get_align_col(self):
356373
# Keep looping until have acceptable answer.
357374
while True:
358375
response = input(prompt).strip().upper()
359-
if response in 'YT1':
376+
if response == '':
377+
print('Invalid answer, please try again.')
378+
elif response in 'YT1':
360379
print('You chose to align columns.')
361380
self.align_col = True
362381
break
@@ -490,19 +509,25 @@ class DBClient(object):
490509
cursor: the cursor to execute this SQL on.
491510
I set cursor = None when cursor closed.
492511
"""
493-
def __init__(self, db_instance, db_type, db_library_name) -> None:
512+
def __init__(self, db_instance) -> None:
494513
""" Constructor method for this class.
495514
496515
Parameters:
497516
db_instance: the handle for a database instance to use.
498517
Returns:
499518
"""
500-
self.db_type = db_type
501-
db_library = __import__(db_library_name)
519+
# Get database cursor.
520+
self.db_instance = db_instance
521+
522+
# Get database type.
523+
self.db_type = self.db_instance.get_db_type()
524+
525+
# Get database library.
526+
self.db_library_name = self.db_instance.get_db_library_name()
527+
db_library = __import__(self.db_library_name)
502528
self.db_library = db_library
503529

504-
# Get database instance object and cursor.
505-
self.db_instance = db_instance
530+
# Get database cursor.
506531
self.cursor = self.db_instance.create_cursor()
507532

508533
# Placeholders for SQL text and bind variables.
@@ -538,8 +563,8 @@ def set_sql(self, sql: str, bind_vars: dict) -> None:
538563
self.bind_vars: dict = bind_vars
539564
return
540565

541-
def get_sql(self) -> None:
542-
""" Get text of SQL at a prompt.
566+
def set_sql_at_prompt(self) -> None:
567+
""" Set text of SQL at a prompt.
543568
544569
Parameters:
545570
Returns:
@@ -568,8 +593,8 @@ def get_sql(self) -> None:
568593
self.sql = sql.strip()
569594
return
570595

571-
def get_bind_vars(self) -> None:
572-
""" Get bind variables at a prompt.
596+
def set_bind_vars_at_prompt(self) -> None:
597+
""" Set bind variables at a prompt.
573598
574599
Parameters:
575600
Returns:
@@ -941,7 +966,7 @@ def database_table_schema(self, colsep='|') -> None:
941966
# Print output.
942967
print()
943968
output_writerx.write_rows(indexes_rows, indexes_col_names)
944-
output_writerx.clean_up()
969+
output_writerx.close_output_file()
945970
return
946971

947972
def database_view_schema(self, colsep='|') -> None:
@@ -1000,7 +1025,7 @@ def database_view_schema(self, colsep='|') -> None:
10001025
# Find and print columns in this view.
10011026
columns_col_names, columns_rows = self.find_view_columns(my_view_name)
10021027
output_writerx.write_rows(columns_rows, columns_col_names)
1003-
output_writerx.clean_up()
1028+
output_writerx.close_output_file()
10041029
return
10051030

10061031
# -------- CUSTOM STAND-ALONE FUNCTIONS
@@ -1134,11 +1159,10 @@ def run_sql_cmdline(sql: str) -> list:
11341159
db_instance1 = DBInstance(db_type1, username1, password1, hostname1,
11351160
port_num1, instance1)
11361161
db_instance1.print_all_connection_parameters()
1137-
db_library_name1 = db_instance1.get_db_library_name()
11381162

11391163
# CREATE DATABASE CLIENT OBJECT.
11401164
print('\nGETTING DATABASE CLIENT OBJECT...')
1141-
my_db_client = DBClient(db_instance1, db_type1, db_library_name1)
1165+
my_db_client = DBClient(db_instance1)
11421166

11431167
# See the database table schema for my login.
11441168
print('\nSEE THE COLUMNS AND INDEXES OF ONE TABLE...')
@@ -1166,8 +1190,7 @@ def run_sql_cmdline(sql: str) -> list:
11661190

11671191
# Set up to write output.
11681192
print('\nPREPARING TO FORMAT THAT OUTPUT, AND PRINT OR WRITE IT TO FILE.')
1169-
output_writer = OutputWriter(out_file_name='', align_col=True,
1170-
col_sep='|')
1193+
output_writer = OutputWriter(out_file_name='', align_col=True, col_sep='|')
11711194
output_writer.get_align_col()
11721195
output_writer.get_col_sep()
11731196
output_writer.get_out_file_name()
@@ -1177,30 +1200,29 @@ def run_sql_cmdline(sql: str) -> list:
11771200

11781201
# Clean up.
11791202
print('\n\nOK, FORGET ALL THAT.')
1180-
output_writer.clean_up()
1203+
output_writer.close_output_file()
11811204
col_names1 = None
11821205
rows1 = None
11831206

11841207
# From a prompt, read in SQL & dict of bind variables & their values.
1185-
my_db_client.get_sql()
1186-
my_db_client.get_bind_vars()
1208+
my_db_client.set_sql_at_prompt()
1209+
my_db_client.set_bind_vars_at_prompt()
11871210

11881211
# Execute the SQL.
11891212
print('\nGETTING THE OUTPUT OF THAT SQL:')
11901213
col_names2, rows2, row_count2 = my_db_client.run_sql()
11911214

11921215
# Set up to write output.
11931216
print('\nPREPARING TO FORMAT THAT OUTPUT, AND PRINT OR WRITE IT TO FILE.')
1194-
output_writer = OutputWriter(out_file_name='', align_col=True,
1195-
col_sep='|')
1217+
output_writer = OutputWriter(out_file_name='', align_col=True, col_sep='|')
11961218
output_writer.get_align_col()
11971219
output_writer.get_col_sep()
11981220
output_writer.get_out_file_name()
11991221
# Show the results.
12001222
output_writer.write_rows(rows2, col_names2)
12011223

12021224
# Clean up.
1203-
output_writer.clean_up()
1225+
output_writer.close_output_file()
12041226
col_names2 = None
12051227
rows2 = None
12061228
print()

universalClient.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -485,34 +485,28 @@ def connect_to_db(db_type, db_host, db_port, db_instance, db_path, db_user,
485485
db_library = __import__(map_type_to_lib[db_type])
486486

487487
if db_type == mysql:
488-
conn_str = ''
489488
z = ''
490489
elif db_type == sql_server:
491-
conn_str = r'DRIVER={SQL Server};'
492-
z = r'UID={};PWD={};SERVER={};PORT={};DATABASE={}'
490+
z = 'DRIVER={{SQL Server}};UID={};PWD={};SERVER={};PORT={};DATABASE={}'
493491
elif db_type == oracle:
494-
conn_str = ''
495492
z = '{}/{}@{}:{}/{}'
496493
elif db_type == postgresql:
497-
conn_str = ''
498494
z = "user='{}' password='{}' host='{}' port='{}' dbname='{}'"
499495
elif db_type == access:
500-
conn_str = r'DRIVER={Microsoft Access Driver (*.mdb, *.accdb)};'
501-
z = r'DBQ={};'
496+
z = 'DRIVER={{Microsoft Access Driver (*.mdb, *.accdb)}};DBQ={};'
502497
elif db_type == sqlite:
503-
conn_str = ''
504498
z = '{}'
505499
else:
506500
print('Unknown db type {}, aborting.'.format(db_type))
507501
raise ExceptionUserAnotherDB()
508502

509503
if db_type in {sql_server, oracle, postgresql}:
510-
conn_str += z.format(db_user, db_password, db_host, db_port, db_instance)
504+
z = z.format(db_user, db_password, db_host, db_port, db_instance)
511505
elif db_type in db_local:
512-
conn_str += z.format(db_path)
506+
z = z.format(db_path)
513507

514508
if db_type in db_uses_conn_str:
515-
connection = db_library.connect(conn_str)
509+
connection = db_library.connect(z)
516510
else:
517511
connection = db_library.connect(db_host, db_user, db_password,
518512
db_instance, db_port)

0 commit comments

Comments
 (0)