Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ __pycache__/
*.log

.DS_Store
.vscode
60 changes: 0 additions & 60 deletions proto/sqlflow.proto

This file was deleted.

8 changes: 7 additions & 1 deletion proto/sqlflow/proto/sqlflow.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@ service SQLFlow {
rpc Run (Request) returns (stream Response);
}

message Session {
string token = 1;
string db_conn_str = 2;
}

// SQL statements to run
// e.g.
// 1. `SELECT ...`
// 2. `USE ...`, `DELETE ...`
// 3. `SELECT ... TRAIN/PREDICT ...`
message Request {
string sql = 1; // The SQL statement to be executed.
string sql = 1; // The SQL statement to be executed.
Session session = 2; // The Session struct including user credentical message.
}

message Response {
Expand Down
12 changes: 9 additions & 3 deletions sqlflow/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ def __init__(self, server_url=None, ca_crt=None):
raise ValueError("Can't find environment variable SQLFLOW_SERVER")
server_url = os.environ["SQLFLOW_SERVER"]

self._stub = pb_grpc.SQLFlowStub(self.newRPCChannel(server_url, ca_crt))
self._stub = pb_grpc.SQLFlowStub(self.new_rpc_channel(server_url, ca_crt))

def newRPCChannel(self, server_url, ca_crt):
def new_rpc_channel(self, server_url, ca_crt):
if ca_crt is None and "SQLFLOW_CA_CRT" not in os.environ:
# client would connect SQLFLow gRPC server with insecure mode.
channel = grpc.insecure_channel(server_url)
Expand All @@ -106,6 +106,12 @@ def newRPCChannel(self, server_url, ca_crt):
channel = grpc.secure_channel(server_url, creds)
return channel

def sql_request(self, sql):
token = os.getenv("SQLFLOW_USER_TOKEN", "")
db_conn_str = os.getenv("SQLFLOW_DATASOURCE", "")
se = pb.Session(token=token, db_conn_str=db_conn_str)
return pb.Request(sql=sql, session=se)

def execute(self, operation):
"""Run a SQL statement

Expand All @@ -120,7 +126,7 @@ def execute(self, operation):

"""
try:
stream_response = self._stub.Run(pb.Request(sql=operation))
stream_response = self._stub.Run(self.sql_request(operation))
return self.display(stream_response)
except grpc.RpcError as e:
_LOGGER.error("%s\n%s", e.code(), e.details())
Expand Down
3 changes: 3 additions & 0 deletions tests/mock_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def Run(self, request, context):
else:
for res in MockServicer.table_response(MockServicer.get_test_table()):
yield res
elif SQL == "TEST VERIFY SESSION":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add comments why need this string, and how we improve this in the future.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a todo comment here, will improve it with an elegant way in the future.

# TODO(Yancey1989): using a elegant way to test the session instead of the trick.
yield MockServicer.message_response("|".join([request.session.token, request.session.db_conn_str]))
else:
yield MockServicer.message_response('bad request', 0)

Expand Down
9 changes: 9 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,12 @@ def test_decode_null(self):
null_message = pb.Row.Null()
any_message.Pack(null_message)
assert Client._decode_any(any_message) is None

def test_session(self):
token = "unittest-user"
ds = "maxcompute://AK:SK@host:port"
os.environ["SQLFLOW_USER_TOKEN"] = token
os.environ["SQLFLOW_DATASOURCE"] = ds
with mock.patch('sqlflow.client._LOGGER') as log_mock:
self.client.execute("TEST VERIFY SESSION")
log_mock.info.assert_called_with("|".join([token, ds]))