Skip to content

Commit 5fdff8e

Browse files
authored
Adapt paiio with DB-API (#2809)
* Add maxcompute DB-API * remove unused import * format code * polish db-api * Adapt paiio with DB-API * Adapt paiio with DB-API * add try import paiio * fix typo
1 parent 9d0af16 commit 5fdff8e

File tree

2 files changed

+152
-4
lines changed

2 files changed

+152
-4
lines changed

python/runtime/dbapi/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
from runtime.dbapi.hive import HiveConnection
1515
from runtime.dbapi.maxcompute import MaxComputeConnection
1616
from runtime.dbapi.mysql import MySQLConnection
17+
from runtime.dbapi.paiio import PaiIOConnection
1718

18-
DRIVRE_MAP = {
19+
DRIVER_MAP = {
1920
"mysql": MySQLConnection,
2021
"hive": HiveConnection,
21-
"maxcompute": MaxComputeConnection
22+
"maxcompute": MaxComputeConnection,
23+
"paiio": PaiIOConnection
2224
}
2325

2426

@@ -37,6 +39,6 @@ def connect(uri):
3739
parts = uri.split("://")
3840
if len(parts) < 2:
3941
raise ValueError("Input should be a valid uri.")
40-
if parts[0] not in DRIVRE_MAP:
42+
if parts[0] not in DRIVER_MAP:
4143
raise ValueError("Can't find driver for scheme: %s" % parts[0])
42-
return DRIVRE_MAP[parts[0]](uri)
44+
return DRIVER_MAP[parts[0]](uri)

python/runtime/dbapi/paiio.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License
13+
14+
import re
15+
16+
from runtime.dbapi.connection import Connection, ResultSet
17+
18+
try:
19+
import paiio
20+
except Exception: # noqa: E722
21+
pass
22+
23+
24+
class PaiIOResultSet(ResultSet):
    """ResultSet backed by a paiio table reader.

    Args:
        reader: the reader object to pull records from, or None when the
            reader could not be created.
        err: error message explaining why the reader is missing, None on
            success.
    """
    def __init__(self, reader, err=None):
        super().__init__()
        self._reader = reader
        # Lazily populated by column_info().
        self._column_info = None
        self._err = err

    def _fetch(self, fetch_size):
        """Read up to `fetch_size` records from the underlying reader.

        Returns:
            The value returned by `reader.read`, or None when the read
            raised (this also covers the case where there is no reader).
        """
        try:
            return self._reader.read(num_records=fetch_size)
        except Exception:  # noqa: E722
            # Deliberate best-effort: a failed read is reported to the
            # caller as "no more data".
            # NOTE(review): presumably paiio raises at end of table; confirm.
            return None

    def column_info(self):
        """Get the result column meta, type in the meta maybe DB-specific

        The schema is read from the reader once and cached.

        Returns:
            A list of column metas, like [(field_a, INT), (field_b, STRING)]
        """
        if self._column_info is not None:
            # Return the cached list, not the bound method `column_info`.
            return self._column_info

        schema = self._reader.get_schema()
        self._column_info = [(c['colname'], c['typestr'].upper())
                             for c in schema]
        return self._column_info

    def success(self):
        """Return True if the query is success"""
        return self._reader is not None

    def error(self):
        """Return the error message passed at construction, if any"""
        return self._err

    def close(self):
        """
        Close the ResultSet explicitly, release any resource incurred
        by this query. Calling it more than once is harmless.
        """
        # getattr: close() is also reached from __del__, which may run on
        # an instance whose __init__ did not complete.
        reader = getattr(self, "_reader", None)
        if reader is None:
            return
        try:
            reader.close()
        finally:
            # Drop the reference even if close() raised so we never try
            # to close the same reader twice.
            self._reader = None

    def __del__(self):
        self.close()
69+
70+
71+
class PaiIOConnection(Connection):
    """PaiIOConnection emulate a connection for paiio,
    currently only support full-table reading. That means
    we can't filter the data, join the table and so on.
    The only supported query statement is `None`.

    Typical use is:
    con = PaiIOConnection("paiio://db/tables/my_table")
    res = con.query(None)
    rows = [r for r in res]
    """
    def __init__(self, conn_uri):
        """Parse `conn_uri` and record database, table and slice params.

        Args:
            conn_uri: uri like paiio://db/tables/table?param_a=a&param_b=b

        Raises:
            ValueError: if no table can be extracted from the uri.
        """
        super().__init__(conn_uri)
        self.driver = "pai_maxcompute"
        # Match against the whole uri rather than self.uripts.path: the
        # pattern includes the "paiio://db" prefix, which a parsed path
        # component does not carry (uripts is presumably a urllib.parse
        # result - it exposes .hostname/.path). `[^?#]+` keeps the query
        # string and fragment out of the table name.
        match = re.findall(r"paiio://\w+/tables/([^?#]+)", conn_uri)
        if len(match) < 1:
            raise ValueError("Should specify table in uri with format: "
                             "paiio://db/tables/table?param_a=a&param_b=b")
        self.params["database"] = self.uripts.hostname
        self.params["table"] = match[0]
        # NOTE(review): values coming from the query string are presumably
        # str while the defaults are int; confirm what TableReader expects.
        self.params["slice_id"] = self.params.get("slice_id", 0)
        self.params["slice_count"] = self.params.get("slice_count", 1)

    def _get_result_set(self, statement):
        """Open a reader on the configured table and wrap it.

        Args:
            statement: must be None, only full table read is supported.

        Returns:
            A PaiIOResultSet; when the reader can't be created the result
            set carries the error message instead of a reader.

        Raises:
            ValueError: if `statement` is not None.
        """
        if statement is not None:
            raise ValueError("paiio only support full table read,"
                             "so you need to pass statement with None.")
        try:
            reader = paiio.TableReader(self.params["table"],
                                       slice_id=self.params["slice_id"],
                                       slice_count=self.params["slice_count"])
            return PaiIOResultSet(reader, None)
        except Exception as e:
            return PaiIOResultSet(None, str(e))

    def get_table_schema(self, full_uri):
        """Get schema of given table, caller need to supply the full
        uri for paiio table, this is slight different with other connections.
        """
        return PaiIOConnection.get_schema(full_uri)

    def query(self, statement=None):
        """Run a query; `statement` defaults to None (full table read)."""
        return super().query(statement)

    @staticmethod
    def get_table_row_num(table_uri):
        """Get row number of given table

        Args:
            table_uri: the full uri for the table to get row from

        Return:
            Number of rows in the table
        """
        reader = paiio.TableReader(table_uri)
        try:
            return reader.get_row_count()
        finally:
            # Release the reader even when counting fails.
            reader.close()

    @staticmethod
    def get_schema(table_uri):
        """Get schema of the given table

        Args:
            table_uri: the full uri for the table to get row from

        Returns:
            A list of column metas, like [(field_a, INT), (field_b, STRING)]
        """
        rs = PaiIOConnection(table_uri).query()
        try:
            return rs.column_info()
        finally:
            # Release the result set even when reading the schema fails.
            rs.close()

    def close(self):
        """Nothing to release: readers are owned by the result sets."""
        pass

0 commit comments

Comments
 (0)