2121from os import path
2222
2323from runtime import db
24+ from runtime .dbapi .maxcompute import MaxComputeConnection
2425from runtime .diagnostics import SQLFlowDiagnostic
2526from runtime .model import EstimatorType , oss
2627from runtime .pai import cluster_conf
# Requirement text for XGBoost jobs: everything already in TF_REQUIREMENT,
# followed by the pinned packages below, one "name==version" entry per line.
XGB_REQUIREMENT = TF_REQUIREMENT + (
    "\n"
    "xgboost==0.82\n"
    "sklearn2pmml==0.56.0\n"
    "sklearn_pandas==1.6.0\n"
)
5153
5254
@@ -93,7 +95,7 @@ def create_tmp_table_from_select(select, datasource):
9395 tmp_tb_name , LIFECYCLE_ON_TMP_TABLE , select )
9496 # (NOTE: lhw) maxcompute conn doesn't support close
9597 # we should unify db interface
96- if not db .execute (conn , create_sql ):
98+ if not conn .execute (create_sql ):
9799 raise SQLFlowDiagnostic ("Can't crate tmp table for %s" % select )
98100 return "%s.%s" % (project , tmp_tb_name )
99101
@@ -105,7 +107,7 @@ def drop_tables(tables, datasource):
105107 for table in tables :
106108 if table != "" :
107109 drop_sql = "DROP TABLE IF EXISTS %s" % table
108- db .execute (conn , drop_sql )
110+ conn .execute (drop_sql )
109111 except : # noqa: E722
110112 # odps will clear table itself, so even fail here, we do
111113 # not need to raise error
@@ -130,15 +132,18 @@ def get_oss_model_url(model_full_path):
130132 return "oss://%s/%s" % (oss .SQLFLOW_MODELS_BUCKET , model_full_path )
131133
132134
def parse_maxcompute_dsn(datasource):
    """Split a MaxCompute datasource string into its connection parts.

    Thin wrapper around MaxComputeConnection.get_uri_parts so that the
    helpers in this module share a single parsing entry point.

    Args:
        datasource: the datasource string to parse

    Returns:
        The value returned by MaxComputeConnection.get_uri_parts. Callers
        in this module unpack it as (user, passwd, address, project).
    """
    uri_parts = MaxComputeConnection.get_uri_parts(datasource)
    return uri_parts
137+
138+
133139def drop_pai_model (datasource , model_name ):
134140 """Drop PAI model
135141
136142 Args:
137143 datasource: current datasource
138144 model_name: name of the model to drop
139145 """
140- dsn = get_datasource_dsn (datasource )
141- user , passwd , address , database = db .parseMaxComputeDSN (dsn )
146+ user , passwd , address , database = parse_maxcompute_dsn (datasource )
142147 cmd = "drop offlinemodel if exists %s" % model_name
143148 subprocess .run ([
144149 "odpscmd" , "-u" , user , "-p" , passwd , "--project" , database ,
@@ -215,8 +220,7 @@ def submit_pai_task(pai_cmd, datasource):
215220 pai_cmd: The command to submit
216221 datasource: The datasource this cmd will manipulate
217222 """
218- dsn = get_datasource_dsn (datasource )
219- user , passwd , address , project = db .parseMaxComputeDSN (dsn )
223+ user , passwd , address , project = parse_maxcompute_dsn (datasource )
220224 cmd = [
221225 "odpscmd" , "--instance-priority" , "9" , "-u" , user , "-p" , passwd ,
222226 "--project" , project , "--endpoint" , address , "-e" , pai_cmd
@@ -230,8 +234,7 @@ def submit_pai_task(pai_cmd, datasource):
def get_oss_model_save_path(datasource, model_name):
    """Build the save path of a model as "project/user/model_name".

    Args:
        datasource: the datasource string; its user and project parts are
            used as path components
        model_name: name of the model; may be empty or None

    Returns:
        The joined path string, or None when model_name is empty/None.
    """
    # Nothing to locate without a model name; skip parsing entirely.
    if not model_name:
        return None

    # Only the user and project parts are needed; the password and
    # endpoint address returned by the parser are ignored.
    user, _, _, project = parse_maxcompute_dsn(datasource)

    # Fall back to a fixed placeholder when the datasource has no user part.
    owner = user or "unknown"
    return "/".join([project, owner, model_name])
237240
@@ -246,8 +249,7 @@ def get_project(datasource):
246249 Args:
247250 datasource: The odps url to extract project
248251 """
249- dsn = get_datasource_dsn (datasource )
250- _ , _ , _ , project = db .parseMaxComputeDSN (dsn )
252+ _ , _ , _ , project = parse_maxcompute_dsn (datasource )
251253 return project
252254
253255
@@ -547,14 +549,14 @@ def create_predict_result_table(datasource, select, result_table, label_column,
547549 model_type: type of model defined in runtime.model.oss
548550 """
549551 conn = db .connect_with_data_source (datasource )
550- db .execute (conn , "DROP TABLE IF EXISTS %s" % result_table )
552+ conn .execute ("DROP TABLE IF EXISTS %s" % result_table )
551553 # PAI ml will create result table itself
552554 if model_type == EstimatorType .PAIML :
553555 return
554556
555557 create_table_sql = "CREATE TABLE %s AS SELECT * FROM %s LIMIT 0" % (
556558 result_table , select )
557- db .execute (conn , create_table_sql )
559+ conn .execute (create_table_sql )
558560
559561 # if label is not in data table, add a int column for it
560562 schema = db .get_table_schema (conn , result_table )
@@ -565,11 +567,11 @@ def create_predict_result_table(datasource, select, result_table, label_column,
565567 break
566568 col_names = [col [0 ] for col in schema ]
567569 if label_column not in col_names :
568- db .execute (
570+ conn .execute (
569571 conn , "ALTER TABLE %s ADD %s %s" %
570572 (result_table , label_column , col_type ))
571573 if train_label_column != label_column and train_label_column in col_names :
572- db .execute (
574+ conn .execute (
573575 conn , "ALTER TABLE %s DROP COLUMN %s" %
574576 (result_table , train_label_column ))
575577
@@ -668,7 +670,7 @@ def create_explain_result_table(datasource, data_table, result_table,
668670 """
669671 conn = db .connect_with_data_source (datasource )
670672 drop_stmt = "DROP TABLE IF EXISTS %s" % result_table
671- db .execute (conn , drop_stmt )
673+ conn .execute (drop_stmt )
672674
673675 create_stmt = ""
674676 if model_type == EstimatorType .PAIML :
@@ -703,7 +705,7 @@ def create_explain_result_table(datasource, data_table, result_table,
703705 "not supported modelType %d for creating Explain result table" %
704706 model_type )
705707
706- if not db .execute (conn , create_stmt ):
708+ if not conn .execute (create_stmt ):
707709 raise SQLFlowDiagnostic ("Can't create explain result table" )
708710
709711
@@ -731,7 +733,7 @@ def get_explain_random_forests_cmd(datasource, model_name, data_table,
731733
732734 conn = db .connect_with_data_source (datasource )
733735 # drop result table if exists
734- db .execute (conn , "DROP TABLE IF EXISTS %s;" % result_table )
736+ conn .execute ("DROP TABLE IF EXISTS %s;" % result_table )
735737 schema = db .get_table_schema (conn , data_table )
736738 fields = [f [0 ] for f in schema if f [0 ] != label_column ]
737739 return ('''pai -name feature_importance -project algo_public '''
@@ -846,7 +848,7 @@ def create_evaluate_result_table(datasource, result_table, metrics):
846848 sql = "CREATE TABLE IF NOT EXISTS %s (%s);" % (result_table ,
847849 "," .join (fields ))
848850 conn = db .connect_with_data_source (datasource )
849- db .execute (conn , sql )
851+ conn .execute (sql )
850852
851853
852854def submit_pai_evaluate (datasource , model_name , select , result_table ,
0 commit comments