|
| 1 | +# Copyright 2020 The SQLFlow Authors. All rights reserved. |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# |
| 6 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 7 | +# |
| 8 | +# Unless required by applicable law or agreed to in writing, software |
| 9 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 10 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 11 | +# See the License for the specific language governing permissions and |
| 12 | +# limitations under the License |
| 13 | + |
| 14 | +from odps import ODPS, tunnel |
| 15 | +from runtime.dbapi.connection import Connection, ResultSet |
| 16 | + |
| 17 | + |
| 18 | +class MaxComputeResultSet(ResultSet): |
| 19 | + """MaxCompute query result""" |
| 20 | + def __init__(self, instance, err=None): |
| 21 | + super().__init__() |
| 22 | + self._instance = instance |
| 23 | + self._column_info = None |
| 24 | + self._err = err |
| 25 | + self._reader = None |
| 26 | + self._read_count = 0 |
| 27 | + |
| 28 | + def _fetch(self, fetch_size): |
| 29 | + r = self._open_reader() |
| 30 | + count = min(fetch_size, r.count - self._read_count) |
| 31 | + rows = [[f[1] for f in row] |
| 32 | + for row in r[self._read_count:self._read_count + count]] |
| 33 | + self._read_count += count |
| 34 | + return rows |
| 35 | + |
| 36 | + def column_info(self): |
| 37 | + """Get the result column meta, type in the meta maybe DB-specific |
| 38 | +
|
| 39 | + Returns: |
| 40 | + A list of column metas, like [(field_a, INT), (field_b, STRING)] |
| 41 | + """ |
| 42 | + if self._column_info is not None: |
| 43 | + return self.column_info |
| 44 | + |
| 45 | + r = self._open_reader() |
| 46 | + self._column_info = [(col.name, col.type) for col in r._schema.columns] |
| 47 | + return self._column_info |
| 48 | + |
| 49 | + def _open_reader(self): |
| 50 | + if not self._reader: |
| 51 | + compress = tunnel.CompressOption.CompressAlgorithm.ODPS_ZLIB |
| 52 | + self._reader = self._instance.open_reader(tunnel=True, |
| 53 | + compress_option=compress) |
| 54 | + return self._reader |
| 55 | + |
| 56 | + def success(self): |
| 57 | + """Return True if the query is success""" |
| 58 | + return self._instance is not None and self._instance.is_successful() |
| 59 | + |
| 60 | + def error(self): |
| 61 | + return self._err |
| 62 | + |
| 63 | + def close(self): |
| 64 | + if self._reader: |
| 65 | + if hasattr(self._reader, "close"): |
| 66 | + self._reader.close() |
| 67 | + self._reader = None |
| 68 | + self._instance = None |
| 69 | + |
| 70 | + def __del__(self): |
| 71 | + self.close() |
| 72 | + |
| 73 | + |
| 74 | +class MaxComputeConnection(Connection): |
| 75 | + """MaxCompute connection, this class uses ODPS object to establish |
| 76 | + connection with maxcompute |
| 77 | +
|
| 78 | + Args: |
| 79 | + conn_uri: uri in format: |
| 80 | + maxcompute://access_id:access_key@service.com/api?curr_project=test_ci&scheme=http |
| 81 | + """ |
| 82 | + def __init__(self, conn_uri): |
| 83 | + super().__init__(conn_uri) |
| 84 | + self.params["database"] = self.params["curr_project"] |
| 85 | + # compose an endpoint, only keep the host and path and replace scheme |
| 86 | + endpoint = self.uripts._replace(scheme=self.params["scheme"], |
| 87 | + query="", |
| 88 | + netloc=self.uripts.hostname) |
| 89 | + self._conn = ODPS(self.uripts.username, |
| 90 | + self.uripts.password, |
| 91 | + project=self.params["database"], |
| 92 | + endpoint=endpoint.geturl()) |
| 93 | + |
| 94 | + def _parse_uri(self): |
| 95 | + return super()._parse_uri() |
| 96 | + |
| 97 | + def _get_result_set(self, statement): |
| 98 | + try: |
| 99 | + instance = self._conn.execute_sql(statement) |
| 100 | + return MaxComputeResultSet(instance) |
| 101 | + except Exception as e: |
| 102 | + return MaxComputeResultSet(None, str(e)) |
| 103 | + |
| 104 | + def close(self): |
| 105 | + if self._conn: |
| 106 | + self._conn = None |
0 commit comments