
Commit a96cb79

Use unified DB-API in codebase (#2821)
* Add maxcompute DB-API
* remove unused import
* format code
* polish db-api
* Adapt paiio with DB-API
* Adapt paiio with DB-API
* add try import paiio
* use db-api in old code
* DB-API support Python2 so can run on PAI
* polish db-api to support Python2 so can run on PAI
* polish db-api to support Python2 so can run on PAI
* polish db-api to support Python2 so can run on PAI
* Use unified DB-API in codebase.
* Use unified DB-API in codebase.
* polish code
* remove debug info
* fix ut
1 parent 6556299 commit a96cb79
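In short, this commit moves statement execution from module-level helpers in `runtime/db.py` onto the `Connection` objects returned by `connect_with_data_source`. A minimal sketch of the call-site change, based on the usages visible in the diffs below (the datasource URI is a placeholder):

```python
# Sketch of the call-site change; the datasource URI is a made-up example.
from runtime import db

conn = db.connect_with_data_source("mysql://user:pass@127.0.0.1:3306/iris")

# Before this commit, module-level helpers took the connection as an argument:
#   db.execute(conn, "DROP TABLE IF EXISTS t")
# After, statements run through the unified Connection object itself:
conn.execute("DROP TABLE IF EXISTS t")
```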

File tree

23 files changed: +276 −781 lines changed


python/runtime/db.py

Lines changed: 55 additions & 409 deletions
Large diffs are not rendered by default.

python/runtime/db_test.py

Lines changed: 77 additions & 192 deletions
Large diffs are not rendered by default.

python/runtime/db_writer/hive.py

Lines changed: 2 additions & 23 deletions
```diff
@@ -39,26 +39,8 @@ def __init__(self,
         self.hdfs_user = hdfs_user
         self.hdfs_pass = hdfs_pass
 
-    def _column_list(self):
-        # NOTE(yancey1989): for the tablename: mydb.tblname, if 'mydb' is
-        # a tablename in the default database, Hive describe STATEMENT would
-        # mistake 'tblname' to a column name.
-        cursor = self.conn.cursor()
-        table_parts = self.table_name.split(".")
-        if len(table_parts) == 2:
-            db, table_name = table_parts[0], table_parts[1]
-            cursor.execute("use %s" % db)
-            cursor.execute("describe %s" % table_name)
-        elif len(table_parts) == 1:
-            cursor.execute("describe %s" % self.table_name)
-        else:
-            raise ValueError("")
-        result = cursor.fetchall()
-        cursor.execute("use %s " % self.conn.default_db)
-        return result
-
     def _indexing_table_schema(self, table_schema):
-        column_list = self._column_list()
+        column_list = self.conn.get_table_schema(self.table_name)
 
         schema_idx = []
         idx_map = {}
@@ -113,12 +95,9 @@ def write_hive_table(self):
             cmd_namenode_str, self.tmp_f.name, hdfs_path, self.table_name)
         subprocess.check_output(cmd_str.split(), env=hdfs_envs)
         # load CSV into Hive
-        cursor = self.conn.cursor()
         load_sql = "LOAD DATA INPATH '%s/%s/' OVERWRITE INTO TABLE %s" % (
             hdfs_path, self.table_name, self.table_name)
-        cursor.execute(load_sql)
-        self.conn.commit()
-        cursor.close()
+        self.conn.execute(load_sql)
 
         # remove the temporary dir on hdfs
         cmd_str = "hdfs dfs %s -rm -r -f %s/%s/" % (cmd_namenode_str,
```
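The deleted `_column_list` method shows the payoff of the unification: qualified-name handling (`db.table`) and the Hive `describe` workaround now live behind `Connection.get_table_schema` instead of being re-implemented in every writer. A hedged sketch of the new calls, with a placeholder Hive URI and table name:

```python
# Hedged sketch; the Hive URI and table name below are placeholders.
from runtime import db

conn = db.connect_with_data_source("hive://user:pass@127.0.0.1:10000/mydb")
column_list = conn.get_table_schema("mydb.tbl")  # e.g. [("a", "INT"), ...]

# Loading data also goes through conn.execute, which takes over the
# cursor/commit/close bookkeeping the old code did by hand.
conn.execute("LOAD DATA INPATH '/tmp/data/' OVERWRITE INTO TABLE mydb.tbl")
```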

python/runtime/dbapi/connection.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -17,8 +17,7 @@
 from six.moves.urllib.parse import parse_qs, urlparse
 
 
-@six.add_metaclass(ABCMeta)
-class ResultSet(object):
+class ResultSet(six.Iterator):
     """Base class for DB query result, caller can iteratable this object
     to get all result rows"""
     def __init__(self):
```
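Swapping the `ABCMeta` metaclass for `six.Iterator` is what makes result sets iterable on the Python 2 runtime PAI ships: on Python 2, `six.Iterator` aliases `next()` to the class's `__next__()`, and on Python 3 it is just `object`. A toy illustration of the pattern (not the real `ResultSet`):

```python
import six


class ToyResultSet(six.Iterator):
    """Yields rows on both Python 2 and 3 with a single __next__."""
    def __init__(self, rows):
        self._it = iter(rows)

    def __iter__(self):
        return self

    def __next__(self):  # six.Iterator maps this to next() on Python 2
        return next(self._it)


for row in ToyResultSet([(1, "a"), (2, "b")]):
    print(row)
```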

python/runtime/dbapi/paiio.py

Lines changed: 56 additions & 26 deletions
```diff
@@ -75,7 +75,11 @@ class PaiIOConnection(Connection):
     currently only support full-table reading. That means
     we can't filter the data, join the table and so on.
     The only supported query statement is `None`. The scheme
-    part of the uri can be 'paiio' or 'odps'
+    part of the uri can be 'paiio' or 'odps'.
+
+    A PaiIOConnection always binds to a specific table.
+    Initializing a PaiIOConnection does not establish any real
+    connection, so feel free to create one whenever needed.
 
     Typical use is:
     con = PaiIOConnection("paiio://db/tables/my_table")
@@ -85,14 +89,17 @@ class PaiIOConnection(Connection):
     def __init__(self, conn_uri):
         super(PaiIOConnection, self).__init__(conn_uri)
         # (TODO: lhw) change driver to paiio
-        self.driver = "pai_maxcompute"
+        self.driver = "paiio"
         match = re.findall(r"\w+://\w+/tables/(.+)", conn_uri)
         if len(match) < 1:
             raise ValueError("Should specify table in uri with format: "
-                             "paiio://db/tables/table?param_a=a&param_b=b")
-        self.params["table"] = conn_uri.replace("paiio://", "odps://")
-        self.params["slice_id"] = self.params.get("slice_id", 0)
-        self.params["slice_count"] = self.params.get("slice_count", 1)
+                             "paiio://db/tables/table?param_a=a&param_b=b, "
+                             "but got: %s" % conn_uri)
+
+        table = self.uripts._replace(scheme="odps", query="")
+        self.params["table"] = table.geturl()
+        self.params["slice_id"] = int(self.params.get("slice_id", "0"))
+        self.params["slice_count"] = int(self.params.get("slice_count", "1"))
 
     def _get_result_set(self, statement):
         if statement is not None:
@@ -106,44 +113,67 @@ def _get_result_set(self, statement):
         except Exception as e:
             return PaiIOResultSet(None, str(e))
 
-    def get_table_schema(self, full_uri):
-        """Get schema of given table, caller need to supply the full
-        uri for paiio table, this is slight different with other connections.
-        """
-        return PaiIOConnection.get_schema(full_uri)
-
     def query(self, statement=None):
         return super(PaiIOConnection, self).query(statement)
 
-    @staticmethod
-    def get_table_row_num(table_uri):
-        """Get row number of given table
-
-        Args:
-            table_uri: the full uri for the table to get row from
+    def get_table_row_num(self):
+        """Get row number of the bound table
 
         Return:
             Number of rows in the table
         """
-        reader = paiio.TableReader(table_uri)
+        reader = paiio.TableReader(self.params["table"])
         row_num = reader.get_row_count()
         reader.close()
         return row_num
 
-    @staticmethod
-    def get_schema(table_uri):
-        """Get schema of the given table
-
-        Args:
-            table_uri: the full uri for the table to get row from
+    def get_schema(self):
+        """Get schema of the bound table
 
         Returns:
             A list of column metas, like [(field_a, INT), (field_b, STRING)]
         """
-        rs = PaiIOConnection(table_uri).query()
+        rs = self.query()
         col_info = rs.column_info()
         rs.close()
         return col_info
 
+    @staticmethod
+    def from_table(table_name, slice_id=0, slice_count=1):
+        """Get a connection object for the given table; if slice_count > 1,
+        bind to a table slice
+
+        Args:
+            table_name: an odps table name in format: db.table
+            slice_id: the slice id for binding
+            slice_count: total slice count
+
+        Returns:
+            A PaiIOConnection object
+        """
+        uri = PaiIOConnection.get_uri_of_table(table_name, slice_id,
+                                               slice_count)
+        return PaiIOConnection(uri)
+
+    @staticmethod
+    def get_uri_of_table(table_name, slice_id=0, slice_count=1):
+        """Build a connection uri from a table name
+
+        Args:
+            table_name: a table name in format: db.table
+            slice_id: the slice id for binding
+            slice_count: total slice count
+
+        Returns:
+            A uri for the table slice with which we can get a connection
+            by PaiIOConnection()
+        """
+        pts = table_name.split(".")
+        if len(pts) != 2:
+            raise ValueError("paiio table name should be in db.table format.")
+        uri = "paiio://%s/tables/%s?slice_id=%d&slice_count=%d" % (
+            pts[0], pts[1], slice_id, slice_count)
+        return uri
+
     def close(self):
         pass
```
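With the connection now bound to a single table, typical use follows the new `from_table` factory. A hedged sketch, assuming `paiio` is importable (i.e. the code runs on PAI) and using a made-up table name:

```python
# Hedged sketch; requires the paiio runtime on PAI, table name is a placeholder.
from runtime.dbapi.paiio import PaiIOConnection

conn = PaiIOConnection.from_table("my_project.my_table")
print(conn.get_table_row_num())  # row count via paiio.TableReader
print(conn.get_schema())         # e.g. [("field_a", "INT"), ...]

rs = conn.query()  # statement must be None: only full-table reads
for row in rs:
    pass           # consume rows here
rs.close()
conn.close()
```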

python/runtime/model/db.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -60,7 +60,7 @@ def write_with_generator(datasource, table, gen):
     _create_table(conn, table)
     idx = 0
 
-    with buffered_db_writer(conn.driver, conn, table, ["id", "block"]) as w:
+    with buffered_db_writer(conn, table, ["id", "block"]) as w:
         for d in gen():
             block = base64.b64encode(d)
             row = [idx, block]
```
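The writer drops its redundant `driver` argument because the unified `Connection` already carries the driver. A sketch of the slimmer API, under the assumption that `buffered_db_writer` and `connect_with_data_source` are importable from `runtime.db` and that the table and column names below are placeholders:

```python
# Hedged sketch; table and column names are placeholders.
from runtime.db import buffered_db_writer, connect_with_data_source

conn = connect_with_data_source("mysql://user:pass@127.0.0.1:3306/sqlflow")
with buffered_db_writer(conn, "sqlflow_models", ["id", "block"]) as w:
    w.write([0, "aGVsbG8="])  # one row per write; flushed in batches
```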

python/runtime/optimize/local.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -231,8 +231,7 @@ def save_solved_result_in_db(solved_result, data_frame, variables,
     data_frame[result_value_name] = solved_result[0]
 
     conn = db.connect_with_data_source(datasource)
-    with db.buffered_db_writer(conn.driver, conn, result_table,
-                               column_names) as w:
+    with db.buffered_db_writer(conn, result_table, column_names) as w:
         for i in six.moves.range(len(data_frame)):
             rows = list(data_frame.loc[i])
             w.write(rows)
```

python/runtime/pai/kmeans.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -52,7 +52,7 @@ def get_train_kmeans_pai_cmd(datasource, model_name, data_table, model_attrs,
     ]
 
     conn = db.connect_with_data_source(datasource)
-    db.execute(conn, "DROP TABLE IF EXISTS %s" % idx_table_name)
+    conn.execute("DROP TABLE IF EXISTS %s" % idx_table_name)
 
     return (
         """pai -name kmeans -project algo_public """
```

python/runtime/pai/random_forest.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -66,7 +66,7 @@ def get_explain_random_forest_pai_cmd(datasource, model_name, data_table,
     conn = db.connect_with_data_source(datasource)
     schema = db.get_table_schema(conn, data_table)
     columns = [f[0] for f in schema]
-    db.execute(conn, "DROP TABLE IF EXISTS %s;" % result_table)
+    conn.execute("DROP TABLE IF EXISTS %s;" % result_table)
     return (
         """pai -name feature_importance -project algo_public """
         """-DmodelName="%s" -DinputTableName="%s" -DoutputTableName="%s" """
```

python/runtime/pai/submitter.py

Lines changed: 20 additions & 18 deletions
```diff
@@ -21,6 +21,7 @@
 from os import path
 
 from runtime import db
+from runtime.dbapi.maxcompute import MaxComputeConnection
 from runtime.diagnostics import SQLFlowDiagnostic
 from runtime.model import EstimatorType, oss
 from runtime.pai import cluster_conf
@@ -47,6 +48,7 @@
 XGB_REQUIREMENT = TF_REQUIREMENT + """
 xgboost==0.82
 sklearn2pmml==0.56.0
+sklearn_pandas==1.6.0
 """
 
 
@@ -93,7 +95,7 @@ def create_tmp_table_from_select(select, datasource):
         tmp_tb_name, LIFECYCLE_ON_TMP_TABLE, select)
     # (NOTE: lhw) maxcompute conn doesn't support close
     # we should unify db interface
-    if not db.execute(conn, create_sql):
+    if not conn.execute(create_sql):
         raise SQLFlowDiagnostic("Can't crate tmp table for %s" % select)
     return "%s.%s" % (project, tmp_tb_name)
 
@@ -105,7 +107,7 @@ def drop_tables(tables, datasource):
         for table in tables:
             if table != "":
                 drop_sql = "DROP TABLE IF EXISTS %s" % table
-                db.execute(conn, drop_sql)
+                conn.execute(drop_sql)
     except:  # noqa: E722
         # odps will clear table itself, so even fail here, we do
         # not need to raise error
@@ -130,15 +132,18 @@ def get_oss_model_url(model_full_path):
     return "oss://%s/%s" % (oss.SQLFLOW_MODELS_BUCKET, model_full_path)
 
 
+def parse_maxcompute_dsn(datasource):
+    return MaxComputeConnection.get_uri_parts(datasource)
+
+
 def drop_pai_model(datasource, model_name):
     """Drop PAI model
 
     Args:
         datasource: current datasource
         model_name: name of the model to drop
     """
-    dsn = get_datasource_dsn(datasource)
-    user, passwd, address, database = db.parseMaxComputeDSN(dsn)
+    user, passwd, address, database = parse_maxcompute_dsn(datasource)
     cmd = "drop offlinemodel if exists %s" % model_name
     subprocess.run([
         "odpscmd", "-u", user, "-p", passwd, "--project", database,
@@ -215,8 +220,7 @@ def submit_pai_task(pai_cmd, datasource):
         pai_cmd: The command to submit
         datasource: The datasource this cmd will manipulate
     """
-    dsn = get_datasource_dsn(datasource)
-    user, passwd, address, project = db.parseMaxComputeDSN(dsn)
+    user, passwd, address, project = parse_maxcompute_dsn(datasource)
     cmd = [
         "odpscmd", "--instance-priority", "9", "-u", user, "-p", passwd,
         "--project", project, "--endpoint", address, "-e", pai_cmd
@@ -230,8 +234,7 @@ def submit_pai_task(pai_cmd, datasource):
 def get_oss_model_save_path(datasource, model_name):
     if not model_name:
         return None
-    dsn = get_datasource_dsn(datasource)
-    user, _, _, project = db.parseMaxComputeDSN(dsn)
+    user, _, _, project = parse_maxcompute_dsn(datasource)
     user = user or "unknown"
     return "/".join([project, user, model_name])
 
@@ -246,8 +249,7 @@ def get_project(datasource):
     Args:
         datasource: The odps url to extract project
     """
-    dsn = get_datasource_dsn(datasource)
-    _, _, _, project = db.parseMaxComputeDSN(dsn)
+    _, _, _, project = parse_maxcompute_dsn(datasource)
     return project
 
 
@@ -547,14 +549,14 @@ def create_predict_result_table(datasource, select, result_table, label_column,
         model_type: type of model defined in runtime.model.oss
     """
     conn = db.connect_with_data_source(datasource)
-    db.execute(conn, "DROP TABLE IF EXISTS %s" % result_table)
+    conn.execute("DROP TABLE IF EXISTS %s" % result_table)
     # PAI ml will create result table itself
     if model_type == EstimatorType.PAIML:
         return
 
     create_table_sql = "CREATE TABLE %s AS SELECT * FROM %s LIMIT 0" % (
         result_table, select)
-    db.execute(conn, create_table_sql)
+    conn.execute(create_table_sql)
 
     # if label is not in data table, add a int column for it
     schema = db.get_table_schema(conn, result_table)
@@ -565,11 +567,11 @@ def create_predict_result_table(datasource, select, result_table, label_column,
             break
     col_names = [col[0] for col in schema]
     if label_column not in col_names:
-        db.execute(
-            conn, "ALTER TABLE %s ADD %s %s" %
+        conn.execute(
+            "ALTER TABLE %s ADD %s %s" %
             (result_table, label_column, col_type))
     if train_label_column != label_column and train_label_column in col_names:
-        db.execute(
-            conn, "ALTER TABLE %s DROP COLUMN %s" %
+        conn.execute(
+            "ALTER TABLE %s DROP COLUMN %s" %
             (result_table, train_label_column))
 
@@ -668,7 +670,7 @@ def create_explain_result_table(datasource, data_table, result_table,
     """
     conn = db.connect_with_data_source(datasource)
     drop_stmt = "DROP TABLE IF EXISTS %s" % result_table
-    db.execute(conn, drop_stmt)
+    conn.execute(drop_stmt)
 
     create_stmt = ""
     if model_type == EstimatorType.PAIML:
@@ -703,7 +705,7 @@ def create_explain_result_table(datasource, data_table, result_table,
             "not supported modelType %d for creating Explain result table" %
             model_type)
 
-    if not db.execute(conn, create_stmt):
+    if not conn.execute(create_stmt):
         raise SQLFlowDiagnostic("Can't create explain result table")
 
 
@@ -731,7 +733,7 @@ def get_explain_random_forests_cmd(datasource, model_name, data_table,
 
     conn = db.connect_with_data_source(datasource)
     # drop result table if exists
-    db.execute(conn, "DROP TABLE IF EXISTS %s;" % result_table)
+    conn.execute("DROP TABLE IF EXISTS %s;" % result_table)
     schema = db.get_table_schema(conn, data_table)
     fields = [f[0] for f in schema if f[0] != label_column]
     return ('''pai -name feature_importance -project algo_public '''
@@ -846,7 +848,7 @@ def create_evaluate_result_table(datasource, result_table, metrics):
     sql = "CREATE TABLE IF NOT EXISTS %s (%s);" % (result_table,
                                                    ",".join(fields))
     conn = db.connect_with_data_source(datasource)
-    db.execute(conn, sql)
+    conn.execute(sql)
 
 
 def submit_pai_evaluate(datasource, model_name, select, result_table,
```
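The new `parse_maxcompute_dsn` helper simply forwards to the dbapi layer, replacing the old `db.parseMaxComputeDSN`. A sketch of the parsing call, assuming a MaxCompute URI of roughly the shape shown (scheme, endpoint, and credentials below are placeholders, and the 4-tuple order matches the call sites in this diff):

```python
# Hedged sketch; credentials, endpoint, and project are placeholders.
from runtime.dbapi.maxcompute import MaxComputeConnection

uri = ("maxcompute://ACCESS_ID:ACCESS_KEY@service.odps.aliyun.com/api"
       "?curr_project=my_project")
user, passwd, address, project = MaxComputeConnection.get_uri_parts(uri)
```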
