Skip to content

Commit fe9da69

Browse files
authored
Add maxcompute DB-API (#2801)
* Add maxcompute DB-API * remove unused import * format code
1 parent ee93f21 commit fe9da69

File tree

2 files changed

+180
-0
lines changed

2 files changed

+180
-0
lines changed

python/runtime/dbapi/maxcompute.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License
13+
14+
from odps import ODPS, tunnel
15+
from runtime.dbapi.connection import Connection, ResultSet
16+
17+
18+
class MaxComputeResultSet(ResultSet):
    """MaxCompute query result

    Rows are read lazily through a reader that is opened on the wrapped
    instance the first time it is needed.

    Args:
        instance: the object returned by `ODPS.execute_sql`, or None when
            the statement could not be submitted
        err: the error message reported by `error()`, if any
    """
    def __init__(self, instance, err=None):
        super().__init__()
        self._instance = instance
        self._column_info = None
        self._err = err
        self._reader = None
        # number of rows already returned by _fetch
        self._read_count = 0

    def _fetch(self, fetch_size):
        """Read the next batch of rows

        Args:
            fetch_size: the maximum number of rows to read

        Returns:
            A list of rows, each row is a list of field values; the list
            is empty when there is nothing left to read
        """
        r = self._open_reader()
        count = min(fetch_size, r.count - self._read_count)
        if count <= 0:
            # nothing left to read, do not slice the reader with an
            # empty range
            return []

        start = self._read_count
        # f[1] is taken as the field value (presumably a record iterates
        # as (name, value) pairs -- confirm against the reader in use)
        rows = [[f[1] for f in row] for row in r[start:start + count]]
        self._read_count += count
        return rows

    def column_info(self):
        """Get the result column meta, type in the meta maybe DB-specific

        The meta is read from the reader schema on the first call and
        cached for the following calls.

        Returns:
            A list of column metas, like [(field_a, INT), (field_b, STRING)]
        """
        if self._column_info is None:
            r = self._open_reader()
            self._column_info = [(col.name, col.type)
                                 for col in r._schema.columns]
        # return the cached list, not the bound method `self.column_info`
        return self._column_info

    def _open_reader(self):
        """Open the reader on the instance once and reuse it afterwards"""
        if not self._reader:
            compress = tunnel.CompressOption.CompressAlgorithm.ODPS_ZLIB
            self._reader = self._instance.open_reader(tunnel=True,
                                                      compress_option=compress)
        return self._reader

    def success(self):
        """Return True if the query is success"""
        return self._instance is not None and self._instance.is_successful()

    def error(self):
        """Return the error message given at construction, None if absent"""
        return self._err

    def close(self):
        """Close the reader if it was opened and drop the instance"""
        if self._reader:
            # not every reader object provides close()
            if hasattr(self._reader, "close"):
                self._reader.close()
            self._reader = None
        self._instance = None

    def __del__(self):
        self.close()
72+
73+
74+
class MaxComputeConnection(Connection):
    """MaxCompute connection, this class uses ODPS object to establish
    connection with maxcompute

    Args:
        conn_uri: uri in format:
        maxcompute://access_id:access_key@service.com/api?curr_project=test_ci&scheme=http

    Raises:
        KeyError: if `curr_project` or `scheme` is missing from the query
            part of conn_uri
        ValueError: if conn_uri carries a port that is not a valid port
            number
    """
    def __init__(self, conn_uri):
        super().__init__(conn_uri)
        self.params["database"] = self.params["curr_project"]
        # compose an endpoint: replace the scheme, drop the query and the
        # credentials, keep the host, the port (if any) and the path
        host = self.uripts.hostname
        port = self.uripts.port
        netloc = host if port is None else "%s:%d" % (host, port)
        endpoint = self.uripts._replace(scheme=self.params["scheme"],
                                        query="",
                                        netloc=netloc)
        self._conn = ODPS(self.uripts.username,
                          self.uripts.password,
                          project=self.params["database"],
                          endpoint=endpoint.geturl())

    def _parse_uri(self):
        """Parse the connection uri, delegated to the base class"""
        return super()._parse_uri()

    def _get_result_set(self, statement):
        """Run the statement and wrap the outcome in a result set

        Args:
            statement: the SQL statement to run

        Returns:
            A MaxComputeResultSet; when running the statement raises, the
            result set holds no instance and carries the error message
        """
        try:
            instance = self._conn.execute_sql(statement)
            return MaxComputeResultSet(instance)
        except Exception as e:
            # report the failure through the result set instead of raising
            return MaxComputeResultSet(None, str(e))

    def close(self):
        """Drop the reference to the ODPS object"""
        if self._conn:
            self._conn = None
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License
13+
14+
import unittest
15+
from unittest import TestCase
16+
17+
from runtime import testing
18+
from runtime.dbapi.maxcompute import MaxComputeConnection
19+
20+
21+
@unittest.skipUnless(testing.get_driver() == "maxcompute",
                     "Skip non-maxcompute test")
class TestMaxComputeConnection(TestCase):
    """Tests of MaxComputeConnection, they only run when the testing
    driver is maxcompute"""
    def test_connecion(self):
        """The connection can be created and closed without error"""
        try:
            conn = MaxComputeConnection(testing.get_datasource())
            conn.close()
        except Exception as e:
            # keep the reason in the failure message
            self.fail("failed to create or close the connection: %s" % e)

    def test_query(self):
        """Query a missing table, then query an existing one"""
        conn = MaxComputeConnection(testing.get_datasource())
        self.addCleanup(conn.close)

        rs = conn.query("select * from notexist limit 1")
        self.assertFalse(rs.success())
        self.assertIn("Table not found", rs.error())

        rs = conn.query(
            "select * from alifin_jtest_dev.sqlflow_iris_train limit 1")
        self.assertTrue(rs.success())
        rows = [r for r in rs]
        self.assertEqual(1, len(rows))

        rs = conn.query(
            "select * from alifin_jtest_dev.sqlflow_iris_train limit 20")
        self.assertTrue(rs.success())

        col_info = rs.column_info()
        self.assertEqual([('sepal_length', 'double'),
                          ('sepal_width', 'double'),
                          ('petal_length', 'double'),
                          ('petal_width', 'double'), ('class', 'bigint')],
                         col_info)

        rows = [r for r in rs]
        # assertTrue(20, len(rows)) always passes: the second argument of
        # assertTrue is the failure message
        self.assertEqual(20, len(rows))

    def test_exec(self):
        """Create a table, fill it, read it back and drop it"""
        conn = MaxComputeConnection(testing.get_datasource())
        self.addCleanup(conn.close)

        rs = conn.exec(
            "create table alifin_jtest_dev.sqlflow_test_exec(a int)")
        self.assertTrue(rs)
        try:
            rs = conn.exec(
                "insert into alifin_jtest_dev.sqlflow_test_exec values(1), (2)"
            )
            self.assertTrue(rs)
            rs = conn.query("select * from alifin_jtest_dev.sqlflow_test_exec")
            self.assertTrue(rs.success())
            rows = [r for r in rs]
            self.assertEqual(2, len(rows))
        finally:
            # drop the table even if a check above failed, so the next
            # run can create it again
            rs = conn.exec("drop table alifin_jtest_dev.sqlflow_test_exec")
        self.assertTrue(rs)
71+
72+
73+
# Run the test cases when this file is executed as a script
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)