Skip to content

Commit 94a63df

Browse files
committed
polish db-api to support Python2 so can run on PAI
1 parent 5fdff8e commit 94a63df

File tree

8 files changed

+118
-48
lines changed

8 files changed

+118
-48
lines changed

python/runtime/dbapi/connection.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
# limitations under the License
1313

1414
from abc import ABCMeta, abstractmethod
15-
from urllib.parse import parse_qs, urlparse
1615

1716
import six
17+
from six.moves.urllib.parse import parse_qs, urlparse
1818

1919

2020
@six.add_metaclass(ABCMeta)
@@ -102,6 +102,11 @@ def __init__(self, conn_uri):
102102
if len(l) == 1:
103103
self.params[k] = l[0]
104104

105+
def param(self, param_name, default_value=""):
106+
if not self.params:
107+
return default_value
108+
return self.params.get(param_name, default_value)
109+
105110
def _parse_uri(self):
106111
"""Parse the connection string into URI parts
107112
Returns:
@@ -139,7 +144,7 @@ def query(self, statement):
139144
"""
140145
return self._get_result_set(statement)
141146

142-
def exec(self, statement):
147+
def execute(self, statement):
143148
"""Execute given statement and return True on success
144149
145150
Args:

python/runtime/dbapi/hive.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,16 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License
1313

14-
from impala.dbapi import connect
14+
try:
15+
from impala.dbapi import connect
16+
except:
17+
pass
1518
from runtime.dbapi.connection import Connection, ResultSet
1619

1720

1821
class HiveResultSet(ResultSet):
1922
def __init__(self, cursor, err=None):
20-
super().__init__()
23+
super(HiveResultSet, self).__init__()
2124
self._cursor = cursor
2225
self._column_info = None
2326
self._err = err
@@ -33,7 +36,7 @@ def column_info(self):
3336
"""
3437

3538
if self._column_info is not None:
36-
return self.column_info
39+
return self._column_info
3740

3841
columns = []
3942
for desc in self._cursor.description:
@@ -66,7 +69,7 @@ class HiveConnection(Connection):
6669
configuration
6770
"""
6871
def __init__(self, conn_uri):
69-
super().__init__(conn_uri)
72+
super(HiveConnection, self).__init__(conn_uri)
7073
self.driver = "hive"
7174
self.params["database"] = self.uripts.path.strip("/")
7275
self._conn = connect(user=self.uripts.username,
@@ -87,6 +90,16 @@ def _get_result_set(self, statement):
8790
cursor.close()
8891
return HiveResultSet(None, str(e))
8992

93+
def cursor(self):
94+
"""Get a cursor on the connection
95+
Prefer not to use low-level APIs such as cursor directly.
96+
Instead, use the query/execute methods of the connection
97+
"""
98+
return self._conn.cursor()
99+
100+
def commit(self):
101+
return self._conn.commit()
102+
90103
def close(self):
91104
if self._conn:
92105
self._conn.close()

python/runtime/dbapi/hive_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,15 @@ def test_query(self):
5151

5252
def test_exec(self):
5353
conn = HiveConnection(testing.get_datasource())
54-
rs = conn.exec("create table test_exec(a int)")
54+
rs = conn.execute("create table test_exec(a int)")
5555
self.assertTrue(rs)
56-
rs = conn.exec("insert into test_exec values(1), (2)")
56+
rs = conn.execute("insert into test_exec values(1), (2)")
5757
self.assertTrue(rs)
5858
rs = conn.query("select * from test_exec")
5959
self.assertTrue(rs.success())
6060
rows = [r for r in rs]
6161
self.assertTrue(2, len(rows))
62-
rs = conn.exec("drop table test_exec")
62+
rs = conn.execute("drop table test_exec")
6363
self.assertTrue(rs)
6464

6565
def test_get_table_schema(self):

python/runtime/dbapi/maxcompute.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@
1313

1414
from odps import ODPS, tunnel
1515
from runtime.dbapi.connection import Connection, ResultSet
16+
from six.moves.urllib.parse import parse_qs, urlparse
17+
18+
COMPRESS_ODPS_ZLIB = tunnel.CompressOption.CompressAlgorithm.ODPS_ZLIB
1619

1720

1821
class MaxComputeResultSet(ResultSet):
1922
"""MaxCompute query result"""
2023
def __init__(self, instance, err=None):
21-
super().__init__()
24+
super(MaxComputeResultSet, self).__init__()
2225
self._instance = instance
2326
self._column_info = None
2427
self._err = err
@@ -40,18 +43,17 @@ def column_info(self):
4043
A list of column metas, like [(field_a, INT), (field_b, STRING)]
4144
"""
4245
if self._column_info is not None:
43-
return self.column_info
46+
return self._column_info
4447

4548
r = self._open_reader()
46-
self._column_info = [(col.name, str.upper(col.type))
49+
self._column_info = [(col.name, str(col.type).upper())
4750
for col in r._schema.columns]
4851
return self._column_info
4952

5053
def _open_reader(self):
5154
if not self._reader:
52-
compress = tunnel.CompressOption.CompressAlgorithm.ODPS_ZLIB
53-
self._reader = self._instance.open_reader(tunnel=True,
54-
compress_option=compress)
55+
self._reader = self._instance.open_reader(
56+
tunnel=True, compress_option=COMPRESS_ODPS_ZLIB)
5557
return self._reader
5658

5759
def success(self):
@@ -81,18 +83,33 @@ class MaxComputeConnection(Connection):
8183
maxcompute://access_id:access_key@service.com/api?curr_project=test_ci&scheme=http
8284
"""
8385
def __init__(self, conn_uri):
84-
super().__init__(conn_uri)
86+
super(MaxComputeConnection, self).__init__(conn_uri)
87+
user, pwd, endpoint, proj = MaxComputeConnection.get_uri_parts(
88+
conn_uri)
8589
self.driver = "maxcompute"
86-
self.params["database"] = self.params["curr_project"]
90+
self.params["database"] = proj
91+
self.endpoint = endpoint
92+
self._conn = ODPS(user, pwd, project=proj, endpoint=endpoint)
93+
94+
@staticmethod
95+
def get_uri_parts(uri):
96+
"""Get username, password, endpoint, project from given uri
97+
98+
Args:
99+
uri: a valid maxcompute connection uri
100+
101+
Returns:
102+
A tuple (username, password, endpoint, project)
103+
"""
104+
uripts = urlparse(uri)
105+
params = parse_qs(uripts.query)
87106
# compose an endpoint, only keep the host and path and replace scheme
88-
endpoint = self.uripts._replace(scheme=self.params["scheme"],
89-
query="",
90-
netloc=self.uripts.hostname)
91-
self.endpoint = endpoint.geturl()
92-
self._conn = ODPS(self.uripts.username,
93-
self.uripts.password,
94-
project=self.params["database"],
95-
endpoint=self.endpoint)
107+
endpoint = uripts._replace(scheme=params.get("scheme", ["http"])[0],
108+
query="",
109+
netloc=uripts.hostname)
110+
endpoint = endpoint.geturl()
111+
return (uripts.username, uripts.password, endpoint,
112+
params.get("curr_project", [""])[0])
96113

97114
def _get_result_set(self, statement):
98115
try:
@@ -108,3 +125,19 @@ def close(self):
108125
def get_table_schema(self, table_name):
109126
schema = self._conn.get_table(table_name).schema
110127
return [(c.name, str(c.type).upper()) for c in schema.columns]
128+
129+
def write_table(self,
130+
table_name,
131+
rows,
132+
compress_option=COMPRESS_ODPS_ZLIB):
133+
"""Append rows to given table, this is a driver specific api
134+
135+
Args:
136+
table_name: the table to write
137+
rows: list of rows, each row is a data tuple, like [(1,True,"ok"),(2,False,"bad")]
138+
compress_option: the compress option defined in
139+
tunnel.CompressOption.CompressAlgorithm
140+
"""
141+
self._conn.write_table(table_name,
142+
rows,
143+
compress_option=compress_option)

python/runtime/dbapi/maxcompute_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,17 @@ def test_query(self):
5656

5757
def test_exec(self):
5858
conn = MaxComputeConnection(testing.get_datasource())
59-
rs = conn.exec(
59+
rs = conn.execute(
6060
"create table alifin_jtest_dev.sqlflow_test_exec(a int)")
6161
self.assertTrue(rs)
62-
rs = conn.exec(
62+
rs = conn.execute(
6363
"insert into alifin_jtest_dev.sqlflow_test_exec values(1), (2)")
6464
self.assertTrue(rs)
6565
rs = conn.query("select * from alifin_jtest_dev.sqlflow_test_exec")
6666
self.assertTrue(rs.success())
6767
rows = [r for r in rs]
6868
self.assertTrue(2, len(rows))
69-
rs = conn.exec("drop table alifin_jtest_dev.sqlflow_test_exec")
69+
rs = conn.execute("drop table alifin_jtest_dev.sqlflow_test_exec")
7070
self.assertTrue(rs)
7171

7272
def test_get_table_schema(self):

python/runtime/dbapi/mysql.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,16 @@
1212
# limitations under the License
1313

1414
import re
15-
from urllib.parse import ParseResult
15+
16+
from runtime.dbapi.connection import Connection, ResultSet
17+
from six.moves.urllib.parse import ParseResult
1618

1719
# NOTE: use MySQLdb to avoid bugs like infinite reading:
1820
# https://bugs.mysql.com/bug.php?id=91971
19-
from MySQLdb import connect
20-
from runtime.dbapi.connection import Connection, ResultSet
21+
try:
22+
from MySQLdb import connect
23+
except:
24+
pass
2125

2226
try:
2327
import MySQLdb.constants.FIELD_TYPE as MYSQL_FIELD_TYPE
@@ -40,7 +44,7 @@
4044

4145
class MySQLResultSet(ResultSet):
4246
def __init__(self, cursor, err=None):
43-
super().__init__()
47+
super(MySQLResultSet, self).__init__()
4448
self._cursor = cursor
4549
self._column_info = None
4650
self._err = err
@@ -55,10 +59,10 @@ def column_info(self):
5559
A list of column metas, like [(field_a, INT), (field_b, STRING)]
5660
"""
5761
if self._column_info is not None:
58-
return self.column_info
62+
return self._column_info
5963

6064
columns = []
61-
for desc in self._cursor.description:
65+
for desc in self._cursor.description or []:
6266
# NOTE: MySQL returns an integer number instead of a string
6367
# to represent the data type.
6468
typ = MYSQL_FIELD_TYPE_DICT.get(desc[1])
@@ -88,7 +92,7 @@ def close(self):
8892

8993
class MySQLConnection(Connection):
9094
def __init__(self, conn_uri):
91-
super().__init__(conn_uri)
95+
super(MySQLConnection, self).__init__(conn_uri)
9296
self.driver = "mysql"
9397
self.params["database"] = self.uripts.path.strip("/")
9498
self._conn = connect(user=self.uripts.username,
@@ -115,6 +119,16 @@ def _get_result_set(self, statement):
115119
cursor.close()
116120
return MySQLResultSet(None, str(e))
117121

122+
def cursor(self):
123+
"""Get a cursor on the connection
124+
Prefer not to use low-level APIs such as cursor directly.
125+
Instead, use the query/execute methods of the connection
126+
"""
127+
return self._conn.cursor()
128+
129+
def commit(self):
130+
return self._conn.commit()
131+
118132
def close(self):
119133
if self._conn:
120134
self._conn.close()

python/runtime/dbapi/mysql_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,17 @@ def test_query(self):
5050

5151
def test_exec(self):
5252
conn = MySQLConnection(testing.get_datasource())
53-
rs = conn.exec("create table test_exec(a int)")
53+
rs = conn.execute("create table test_exec(a int)")
5454
self.assertTrue(rs)
55-
rs = conn.exec("insert into test_exec values(1), (2)")
55+
rs = conn.execute("insert into test_exec values(1), (2)")
5656
self.assertTrue(rs)
5757
rs = conn.query("select * from test_exec")
5858
self.assertTrue(rs.success())
5959
rows = [r for r in rs]
6060
self.assertTrue(2, len(rows))
61-
rs = conn.exec("drop table test_exec")
61+
rs = conn.execute("drop table test_exec")
6262
self.assertTrue(rs)
63-
rs = conn.exec("drop table not_exist")
63+
rs = conn.execute("drop table not_exist")
6464
self.assertFalse(rs)
6565

6666
def test_get_table_schema(self):

python/runtime/dbapi/paiio.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License
1313

14+
from __future__ import absolute_import
15+
1416
import re
1517

1618
from runtime.dbapi.connection import Connection, ResultSet
@@ -23,7 +25,7 @@
2325

2426
class PaiIOResultSet(ResultSet):
2527
def __init__(self, reader, err=None):
26-
super().__init__()
28+
super(PaiIOResultSet, self).__init__()
2729
self._reader = reader
2830
self._column_info = None
2931
self._err = err
@@ -41,10 +43,10 @@ def column_info(self):
4143
A list of column metas, like [(field_a, INT), (field_b, STRING)]
4244
"""
4345
if self._column_info is not None:
44-
return self.column_info
46+
return self._column_info
4547

4648
schema = self._reader.get_schema()
47-
columns = [(c['colname'], str.upper(c['typestr'])) for c in schema]
49+
columns = [(c['colname'], str(c['typestr']).upper()) for c in schema]
4850
self._column_info = columns
4951
return self._column_info
5052

@@ -72,24 +74,26 @@ class PaiIOConnection(Connection):
7274
"""PaiIOConnection emulates a connection for paiio,
7375
currently only supports full-table reading. That means
7476
we can't filter the data, join the table and so on.
75-
The only supported query statement is `None`.
77+
The only supported query statement is `None`. The scheme
78+
part of the uri can be 'paiio' or 'odps'
7679
7780
Typical use is:
7881
con = PaiIOConnection("paiio://db/tables/my_table")
7982
res = con.query(None)
8083
rows = [r for r in res]
8184
"""
8285
def __init__(self, conn_uri):
83-
super().__init__(conn_uri)
86+
super(PaiIOConnection, self).__init__(conn_uri)
87+
# TODO(lhw): change driver to paiio
8488
self.driver = "pai_maxcompute"
85-
match = re.findall(r"paiio://\w+/tables/(.+)", self.uripts.path)
89+
match = re.findall(r"\w+://\w+/tables/(.+)", conn_uri)
8690
if len(match) < 1:
8791
raise ValueError("Should specify table in uri with format: "
8892
"paiio://db/tables/table?param_a=a&param_b=b")
89-
self.params["database"] = self.uripts.hostname
90-
self.params["table"] = match[0]
93+
self.params["table"] = conn_uri.replace("paiio://", "odps://")
9194
self.params["slice_id"] = self.params.get("slice_id", 0)
9295
self.params["slice_count"] = self.params.get("slice_count", 1)
96+
print(self.params)
9397

9498
def _get_result_set(self, statement):
9599
if statement is not None:
@@ -101,6 +105,7 @@ def _get_result_set(self, statement):
101105
slice_count=self.params["slice_count"])
102106
return PaiIOResultSet(reader, None)
103107
except Exception as e:
108+
print(e.args)
104109
return PaiIOResultSet(None, str(e))
105110

106111
def get_table_schema(self, full_uri):
@@ -110,7 +115,7 @@ def get_table_schema(self, full_uri):
110115
return PaiIOConnection.get_schema(full_uri)
111116

112117
def query(self, statement=None):
113-
return super().query(statement)
118+
return super(PaiIOConnection, self).query(statement)
114119

115120
@staticmethod
116121
def get_table_row_num(table_uri):

0 commit comments

Comments
 (0)