Skip to content

Commit 4244e16

Browse files
authored
add session struct in proto (#71)
* add session struct in proto * add todo comment
1 parent 6a42bd1 commit 4244e16

File tree

6 files changed

+29
-64
lines changed

6 files changed

+29
-64
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ __pycache__/
1717
*.log
1818

1919
.DS_Store
20+
.vscode

proto/sqlflow.proto

Lines changed: 0 additions & 60 deletions
This file was deleted.

proto/sqlflow/proto/sqlflow.proto

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,19 @@ service SQLFlow {
1919
rpc Run (Request) returns (stream Response);
2020
}
2121

22+
message Session {
23+
string token = 1;
24+
string db_conn_str = 2;
25+
}
26+
2227
// SQL statements to run
2328
// e.g.
2429
// 1. `SELECT ...`
2530
// 2. `USE ...`, `DELETE ...`
2631
// 3. `SELECT ... TRAIN/PREDICT ...`
2732
message Request {
28-
string sql = 1; // The SQL statement to be executed.
33+
string sql = 1; // The SQL statement to be executed.
34+
Session session = 2; // The Session struct including user credentical message.
2935
}
3036

3137
message Response {

sqlflow/client.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ def __init__(self, server_url=None, ca_crt=None):
9292
raise ValueError("Can't find environment variable SQLFLOW_SERVER")
9393
server_url = os.environ["SQLFLOW_SERVER"]
9494

95-
self._stub = pb_grpc.SQLFlowStub(self.newRPCChannel(server_url, ca_crt))
95+
self._stub = pb_grpc.SQLFlowStub(self.new_rpc_channel(server_url, ca_crt))
9696

97-
def newRPCChannel(self, server_url, ca_crt):
97+
def new_rpc_channel(self, server_url, ca_crt):
9898
if ca_crt is None and "SQLFLOW_CA_CRT" not in os.environ:
9999
# client would connect SQLFLow gRPC server with insecure mode.
100100
channel = grpc.insecure_channel(server_url)
@@ -106,6 +106,12 @@ def newRPCChannel(self, server_url, ca_crt):
106106
channel = grpc.secure_channel(server_url, creds)
107107
return channel
108108

109+
def sql_request(self, sql):
110+
token = os.getenv("SQLFLOW_USER_TOKEN", "")
111+
db_conn_str = os.getenv("SQLFLOW_DATASOURCE", "")
112+
se = pb.Session(token=token, db_conn_str=db_conn_str)
113+
return pb.Request(sql=sql, session=se)
114+
109115
def execute(self, operation):
110116
"""Run a SQL statement
111117
@@ -120,7 +126,7 @@ def execute(self, operation):
120126
121127
"""
122128
try:
123-
stream_response = self._stub.Run(pb.Request(sql=operation))
129+
stream_response = self._stub.Run(self.sql_request(operation))
124130
return self.display(stream_response)
125131
except grpc.RpcError as e:
126132
_LOGGER.error("%s\n%s", e.code(), e.details())

tests/mock_servicer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ def Run(self, request, context):
2020
else:
2121
for res in MockServicer.table_response(MockServicer.get_test_table()):
2222
yield res
23+
elif SQL == "TEST VERIFY SESSION":
24+
# TODO(Yancey1989): using a elegant way to test the session instead of the trick.
25+
yield MockServicer.message_response("|".join([request.session.token, request.session.db_conn_str]))
2326
else:
2427
yield MockServicer.message_response('bad request', 0)
2528

tests/test_client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,12 @@ def test_decode_null(self):
6767
null_message = pb.Row.Null()
6868
any_message.Pack(null_message)
6969
assert Client._decode_any(any_message) is None
70+
71+
def test_session(self):
72+
token = "unittest-user"
73+
ds = "maxcompute://AK:SK@host:port"
74+
os.environ["SQLFLOW_USER_TOKEN"] = token
75+
os.environ["SQLFLOW_DATASOURCE"] = ds
76+
with mock.patch('sqlflow.client._LOGGER') as log_mock:
77+
self.client.execute("TEST VERIFY SESSION")
78+
log_mock.info.assert_called_with("|".join([token, ds]))

0 commit comments

Comments
 (0)