Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/runtime/dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# limitations under the License

from abc import ABCMeta, abstractmethod
from urllib.parse import parse_qs, urlparse

import six
from six.moves.urllib.parse import parse_qs, urlparse


@six.add_metaclass(ABCMeta)
Expand Down Expand Up @@ -102,6 +102,11 @@ def __init__(self, conn_uri):
if len(l) == 1:
self.params[k] = l[0]

def param(self, param_name, default_value=""):
    """Look up a single connection parameter by name.

    Args:
        param_name: key of the parameter to fetch from self.params
        default_value: value handed back when the parameter is absent

    Returns:
        The stored value for param_name, or default_value when
        self.params is empty/unset or has no such key.
    """
    known = self.params
    return known.get(param_name, default_value) if known else default_value

def _parse_uri(self):
"""Parse the connection string into URI parts
Returns:
Expand Down Expand Up @@ -139,7 +144,7 @@ def query(self, statement):
"""
return self._get_result_set(statement)

def exec(self, statement):
def execute(self, statement):
"""Execute given statement and return True on success

Args:
Expand Down
21 changes: 17 additions & 4 deletions python/runtime/dbapi/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
# See the License for the specific language governing permissions and
# limitations under the License

from impala.dbapi import connect
try:
from impala.dbapi import connect
except: # noqa E722
pass
from runtime.dbapi.connection import Connection, ResultSet


class HiveResultSet(ResultSet):
def __init__(self, cursor, err=None):
super().__init__()
super(HiveResultSet, self).__init__()
self._cursor = cursor
self._column_info = None
self._err = err
Expand All @@ -33,7 +36,7 @@ def column_info(self):
"""

if self._column_info is not None:
return self.column_info
return self._column_info

columns = []
for desc in self._cursor.description:
Expand Down Expand Up @@ -66,7 +69,7 @@ class HiveConnection(Connection):
configuration
"""
def __init__(self, conn_uri):
super().__init__(conn_uri)
super(HiveConnection, self).__init__(conn_uri)
self.driver = "hive"
self.params["database"] = self.uripts.path.strip("/")
self._conn = connect(user=self.uripts.username,
Expand All @@ -87,6 +90,16 @@ def _get_result_set(self, statement):
cursor.close()
return HiveResultSet(None, str(e))

def cursor(self):
    """Return a new cursor created from the underlying connection.

    This is a low level escape hatch; callers are encouraged to go
    through query/execute rather than driving a cursor directly.
    """
    raw_conn = self._conn
    return raw_conn.cursor()

def commit(self):
    """Commit on the underlying connection and return its result."""
    outcome = self._conn.commit()
    return outcome

def close(self):
if self._conn:
self._conn.close()
Expand Down
6 changes: 3 additions & 3 deletions python/runtime/dbapi/hive_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ def test_query(self):

def test_exec(self):
conn = HiveConnection(testing.get_datasource())
rs = conn.exec("create table test_exec(a int)")
rs = conn.execute("create table test_exec(a int)")
self.assertTrue(rs)
rs = conn.exec("insert into test_exec values(1), (2)")
rs = conn.execute("insert into test_exec values(1), (2)")
self.assertTrue(rs)
rs = conn.query("select * from test_exec")
self.assertTrue(rs.success())
rows = [r for r in rs]
self.assertTrue(2, len(rows))
rs = conn.exec("drop table test_exec")
rs = conn.execute("drop table test_exec")
self.assertTrue(rs)

def test_get_table_schema(self):
Expand Down
66 changes: 50 additions & 16 deletions python/runtime/dbapi/maxcompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@

from odps import ODPS, tunnel
from runtime.dbapi.connection import Connection, ResultSet
from six.moves.urllib.parse import parse_qs, urlparse

COMPRESS_ODPS_ZLIB = tunnel.CompressOption.CompressAlgorithm.ODPS_ZLIB


class MaxComputeResultSet(ResultSet):
"""MaxCompute query result"""
def __init__(self, instance, err=None):
super().__init__()
super(MaxComputeResultSet, self).__init__()
self._instance = instance
self._column_info = None
self._err = err
Expand All @@ -40,18 +43,17 @@ def column_info(self):
A list of column metas, like [(field_a, INT), (field_b, STRING)]
"""
if self._column_info is not None:
return self.column_info
return self._column_info

r = self._open_reader()
self._column_info = [(col.name, str.upper(col.type))
self._column_info = [(col.name, str(col.type).upper())
for col in r._schema.columns]
return self._column_info

def _open_reader(self):
if not self._reader:
compress = tunnel.CompressOption.CompressAlgorithm.ODPS_ZLIB
self._reader = self._instance.open_reader(tunnel=True,
compress_option=compress)
self._reader = self._instance.open_reader(
tunnel=True, compress_option=COMPRESS_ODPS_ZLIB)
return self._reader

def success(self):
Expand Down Expand Up @@ -81,18 +83,33 @@ class MaxComputeConnection(Connection):
maxcompute://access_id:access_key@service.com/api?curr_project=test_ci&scheme=http
"""
def __init__(self, conn_uri):
super().__init__(conn_uri)
super(MaxComputeConnection, self).__init__(conn_uri)
user, pwd, endpoint, proj = MaxComputeConnection.get_uri_parts(
conn_uri)
self.driver = "maxcompute"
self.params["database"] = self.params["curr_project"]
self.params["database"] = proj
self.endpoint = endpoint
self._conn = ODPS(user, pwd, project=proj, endpoint=endpoint)

@staticmethod
def get_uri_parts(uri):
"""Get username, password, endpoint and project from the given uri

Args:
uri: a valid maxcompute connection uri

Returns:
A tuple (username, password, endpoint, project)
"""
uripts = urlparse(uri)
params = parse_qs(uripts.query)
# compose an endpoint, only keep the host and path and replace scheme
endpoint = self.uripts._replace(scheme=self.params["scheme"],
query="",
netloc=self.uripts.hostname)
self.endpoint = endpoint.geturl()
self._conn = ODPS(self.uripts.username,
self.uripts.password,
project=self.params["database"],
endpoint=self.endpoint)
endpoint = uripts._replace(scheme=params.get("scheme", ["http"])[0],
query="",
netloc=uripts.hostname)
endpoint = endpoint.geturl()
return (uripts.username, uripts.password, endpoint,
params.get("curr_project", [""])[0])

def _get_result_set(self, statement):
try:
Expand All @@ -108,3 +125,20 @@ def close(self):
def get_table_schema(self, table_name):
schema = self._conn.get_table(table_name).schema
return [(c.name, str(c.type).upper()) for c in schema.columns]

def write_table(self,
                table_name,
                rows,
                compress_option=COMPRESS_ODPS_ZLIB):
    """Append rows to the given table (driver specific API).

    Args:
        table_name: name of the table to write to
        rows: list of rows, each row being a data tuple,
            like [(1,True,"ok"),(2,False,"bad")]
        compress_option: the compress option, as defined in
            tunnel.CompressOption.CompressAlgorithm
    """
    odps_client = self._conn
    odps_client.write_table(table_name,
                            rows,
                            compress_option=compress_option)
6 changes: 3 additions & 3 deletions python/runtime/dbapi/maxcompute_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,17 @@ def test_query(self):

def test_exec(self):
conn = MaxComputeConnection(testing.get_datasource())
rs = conn.exec(
rs = conn.execute(
"create table alifin_jtest_dev.sqlflow_test_exec(a int)")
self.assertTrue(rs)
rs = conn.exec(
rs = conn.execute(
"insert into alifin_jtest_dev.sqlflow_test_exec values(1), (2)")
self.assertTrue(rs)
rs = conn.query("select * from alifin_jtest_dev.sqlflow_test_exec")
self.assertTrue(rs.success())
rows = [r for r in rs]
self.assertTrue(2, len(rows))
rs = conn.exec("drop table alifin_jtest_dev.sqlflow_test_exec")
rs = conn.execute("drop table alifin_jtest_dev.sqlflow_test_exec")
self.assertTrue(rs)

def test_get_table_schema(self):
Expand Down
28 changes: 21 additions & 7 deletions python/runtime/dbapi/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
# limitations under the License

import re
from urllib.parse import ParseResult

from runtime.dbapi.connection import Connection, ResultSet
from six.moves.urllib.parse import ParseResult

# NOTE: use MySQLdb to avoid bugs like infinite reading:
# https://bugs.mysql.com/bug.php?id=91971
from MySQLdb import connect
from runtime.dbapi.connection import Connection, ResultSet
try:
from MySQLdb import connect
except: # noqa E722
pass

try:
import MySQLdb.constants.FIELD_TYPE as MYSQL_FIELD_TYPE
Expand All @@ -40,7 +44,7 @@

class MySQLResultSet(ResultSet):
def __init__(self, cursor, err=None):
super().__init__()
super(MySQLResultSet, self).__init__()
self._cursor = cursor
self._column_info = None
self._err = err
Expand All @@ -55,10 +59,10 @@ def column_info(self):
A list of column metas, like [(field_a, INT), (field_b, STRING)]
"""
if self._column_info is not None:
return self.column_info
return self._column_info

columns = []
for desc in self._cursor.description:
for desc in self._cursor.description or []:
# NOTE: MySQL returns an integer number instead of a string
# to represent the data type.
typ = MYSQL_FIELD_TYPE_DICT.get(desc[1])
Expand Down Expand Up @@ -88,7 +92,7 @@ def close(self):

class MySQLConnection(Connection):
def __init__(self, conn_uri):
super().__init__(conn_uri)
super(MySQLConnection, self).__init__(conn_uri)
self.driver = "mysql"
self.params["database"] = self.uripts.path.strip("/")
self._conn = connect(user=self.uripts.username,
Expand All @@ -115,6 +119,16 @@ def _get_result_set(self, statement):
cursor.close()
return MySQLResultSet(None, str(e))

def cursor(self):
    """Return a new cursor created from the underlying connection.

    This is a low level escape hatch; callers are encouraged to go
    through query/execute rather than driving a cursor directly.
    """
    raw_conn = self._conn
    return raw_conn.cursor()

def commit(self):
    """Commit on the underlying connection and return its result."""
    outcome = self._conn.commit()
    return outcome

def close(self):
if self._conn:
self._conn.close()
Expand Down
8 changes: 4 additions & 4 deletions python/runtime/dbapi/mysql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,17 @@ def test_query(self):

def test_exec(self):
conn = MySQLConnection(testing.get_datasource())
rs = conn.exec("create table test_exec(a int)")
rs = conn.execute("create table test_exec(a int)")
self.assertTrue(rs)
rs = conn.exec("insert into test_exec values(1), (2)")
rs = conn.execute("insert into test_exec values(1), (2)")
self.assertTrue(rs)
rs = conn.query("select * from test_exec")
self.assertTrue(rs.success())
rows = [r for r in rs]
self.assertTrue(2, len(rows))
rs = conn.exec("drop table test_exec")
rs = conn.execute("drop table test_exec")
self.assertTrue(rs)
rs = conn.exec("drop table not_exist")
rs = conn.execute("drop table not_exist")
self.assertFalse(rs)

def test_get_table_schema(self):
Expand Down
21 changes: 12 additions & 9 deletions python/runtime/dbapi/paiio.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# See the License for the specific language governing permissions and
# limitations under the License

from __future__ import absolute_import

import re

from runtime.dbapi.connection import Connection, ResultSet
Expand All @@ -23,7 +25,7 @@

class PaiIOResultSet(ResultSet):
def __init__(self, reader, err=None):
super().__init__()
super(PaiIOResultSet, self).__init__()
self._reader = reader
self._column_info = None
self._err = err
Expand All @@ -41,10 +43,10 @@ def column_info(self):
A list of column metas, like [(field_a, INT), (field_b, STRING)]
"""
if self._column_info is not None:
return self.column_info
return self._column_info

schema = self._reader.get_schema()
columns = [(c['colname'], str.upper(c['typestr'])) for c in schema]
columns = [(c['colname'], str(c['typestr']).upper()) for c in schema]
self._column_info = columns
return self._column_info

Expand Down Expand Up @@ -72,22 +74,23 @@ class PaiIOConnection(Connection):
"""PaiIOConnection emulate a connection for paiio,
currently only support full-table reading. That means
we can't filter the data, join the table and so on.
The only supported query statement is `None`.
The only supported query statement is `None`. The scheme
part of the uri can be 'paiio' or 'odps'

Typical use is:
con = PaiIOConnection("paiio://db/tables/my_table")
res = con.query(None)
rows = [r for r in res]
"""
def __init__(self, conn_uri):
super().__init__(conn_uri)
super(PaiIOConnection, self).__init__(conn_uri)
# (TODO: lhw) change driver to paiio
self.driver = "pai_maxcompute"
match = re.findall(r"paiio://\w+/tables/(.+)", self.uripts.path)
match = re.findall(r"\w+://\w+/tables/(.+)", conn_uri)
if len(match) < 1:
raise ValueError("Should specify table in uri with format: "
"paiio://db/tables/table?param_a=a&param_b=b")
self.params["database"] = self.uripts.hostname
self.params["table"] = match[0]
self.params["table"] = conn_uri.replace("paiio://", "odps://")
self.params["slice_id"] = self.params.get("slice_id", 0)
self.params["slice_count"] = self.params.get("slice_count", 1)

Expand All @@ -110,7 +113,7 @@ def get_table_schema(self, full_uri):
return PaiIOConnection.get_schema(full_uri)

def query(self, statement=None):
return super().query(statement)
return super(PaiIOConnection, self).query(statement)

@staticmethod
def get_table_row_num(table_uri):
Expand Down
1 change: 1 addition & 0 deletions scripts/test/hive.sh
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@ go generate ./...
go install ./...
gotest -p 1 -covermode=count -coverprofile=coverage.txt -timeout 1800s -v ./...
python -m unittest discover -v python "db_test.py"
python -m unittest discover -v python "hive_test.py"