Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions skyflow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,13 @@ class InfoMessages(Enum):
INSERT_DATA_SUCCESS = "Data has been inserted successfully."
DETOKENIZE_SUCCESS = "Data has been detokenized successfully."
GET_BY_ID_SUCCESS = "Data fetched from ID successfully."
QUERY_SUCCESS = "Query executed successfully."
BEARER_TOKEN_RECEIVED = "tokenProvider returned token successfully."
INSERT_TRIGGERED = "Insert method triggered."
DETOKENIZE_TRIGGERED = "Detokenize method triggered."
GET_BY_ID_TRIGGERED = "Get by ID triggered."
INVOKE_CONNECTION_TRIGGERED = "Invoke connection triggered."
QUERY_TRIGGERED = "Query method triggered."
GENERATE_BEARER_TOKEN_TRIGGERED = "Generate bearer token triggered"
GENERATE_BEARER_TOKEN_SUCCESS = "Generate bearer token returned successfully"
IS_TOKEN_VALID_TRIGGERED = "isTokenValid() triggered"
Expand All @@ -87,6 +89,7 @@ class InterfaceName(Enum):
GET = "client.get"
UPDATE = "client.update"
INVOKE_CONNECTION = "client.invoke_connection"
QUERY = "client.query"
GENERATE_BEARER_TOKEN = "service_account.generate_bearer_token"

IS_TOKEN_VALID = "service_account.isTokenValid"
Expand Down
5 changes: 5 additions & 0 deletions skyflow/errors/_skyflow_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ class SkyflowErrorMessages(Enum):
INVALID_UPSERT_COLUMN_TYPE = "upsert object column key has value of type %s, expected string"
EMPTY_UPSERT_OPTION_TABLE = "upsert object table value is empty string at index %s, expected non-empty string"
EMPTY_UPSERT_OPTION_COLUMN = "upsert object column value is empty string at index %s, expected non-empty string"
QUERY_KEY_ERROR = "Query key is missing from payload"
INVALID_QUERY_TYPE = "Query key has value of type %s, expected string"
EMPTY_QUERY = "Query key cannot be empty"
INVALID_QUERY_COMMAND = "only SELECT commands are supported, %s command was passed instead"
SERVER_ERROR = "Server returned errors, check SkyflowError.data for more"

class SkyflowError(Exception):
def __init__(self, code, message="An Error occured", data={}, interface: str = 'Unknown') -> None:
Expand Down
27 changes: 24 additions & 3 deletions skyflow/vault/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from ._delete import deleteProcessResponse
from ._insert import getInsertRequestBody, processResponse, convertResponse
from ._update import sendUpdateRequests, createUpdateResponseBody
from ._config import Configuration, DeleteOptions
from ._config import InsertOptions, ConnectionConfig, UpdateOptions
from ._config import Configuration, DeleteOptions, InsertOptions, ConnectionConfig, UpdateOptions, QueryOptions
from ._connection import createRequest
from ._detokenize import sendDetokenizeRequests, createDetokenizeResponseBody
from ._get_by_id import sendGetByIdRequests, createGetResponseBody
Expand All @@ -18,7 +17,7 @@
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow._utils import log_info, InfoMessages, InterfaceName, getMetrics
from ._token import tokenProviderWrapper

from ._query import getQueryRequestBody, getQueryResponse

class Client:
def __init__(self, config: Configuration):
Expand Down Expand Up @@ -141,6 +140,28 @@ def invoke_connection(self, config: ConnectionConfig):
session.close()
return processResponse(response, interface=interface)

def query(self, queryInput, options: QueryOptions = QueryOptions()):
interface = InterfaceName.QUERY.value
log_info(InfoMessages.QUERY_TRIGGERED.value, interface=interface)

self._checkConfig(interface)

jsonBody = getQueryRequestBody(queryInput, options)
requestURL = self._get_complete_vault_url() + "/query"
self.storedToken = tokenProviderWrapper(
self.storedToken, self.tokenProvider, interface)
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + self.storedToken,
"sky-metadata": json.dumps(getMetrics())
}

response = requests.post(requestURL, data=jsonBody, headers=headers)
result = getQueryResponse(response)

log_info(InfoMessages.QUERY_SUCCESS.value, interface)
return result

def _checkConfig(self, interface):
'''
Performs basic check on the given client config
Expand Down
4 changes: 4 additions & 0 deletions skyflow/vault/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def __init__(self, tokens: bool=True):
class DeleteOptions:
def __init__(self, tokens: bool=False):
self.tokens = tokens

class QueryOptions:
def __init__(self):
pass

class RequestMethod(Enum):
GET = 'GET'
Expand Down
62 changes: 62 additions & 0 deletions skyflow/vault/_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
'''
Copyright (c) 2022 Skyflow, Inc.
'''
import json

import requests
from ._config import QueryOptions
from requests.models import HTTPError
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow._utils import InterfaceName

interface = InterfaceName.QUERY.value


def getQueryRequestBody(data, options):
try:
query = data["query"]
except KeyError:
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
SkyflowErrorMessages.QUERY_KEY_ERROR, interface=interface)

if not isinstance(query, str):
queryType = str(type(query))
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_QUERY_TYPE.value % queryType, interface=interface)

if not query.strip():
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,SkyflowErrorMessages.EMPTY_QUERY.value, interface=interface)

requestBody = {"query": query}
try:
jsonBody = json.dumps(requestBody)
except Exception as e:
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_JSON.value % (
'query payload'), interface=interface)

return jsonBody

def getQueryResponse(response: requests.Response, interface=interface):
statusCode = response.status_code
content = response.content.decode('utf-8')
try:
response.raise_for_status()
try:
return json.loads(content)
except:
raise SkyflowError(
statusCode, SkyflowErrorMessages.RESPONSE_NOT_JSON.value % content, interface=interface)
except HTTPError:
message = SkyflowErrorMessages.API_ERROR.value % statusCode
if response != None and response.content != None:
try:
errorResponse = json.loads(content)
if 'error' in errorResponse and type(errorResponse['error']) == type({}) and 'message' in errorResponse['error']:
message = errorResponse['error']['message']
except:
message = SkyflowErrorMessages.RESPONSE_NOT_JSON.value % content
raise SkyflowError(SkyflowErrorCodes.INVALID_INDEX, message, interface=interface)
error = {"error": {}}
if 'x-request-id' in response.headers:
message += ' - request id: ' + response.headers['x-request-id']
error['error'].update({"code": statusCode, "description": message})
raise SkyflowError(SkyflowErrorCodes.SERVER_ERROR, SkyflowErrorMessages.SERVER_ERROR.value, error, interface=interface)
175 changes: 175 additions & 0 deletions tests/vault/test_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
'''
Copyright (c) 2022 Skyflow, Inc.
'''
import json
import unittest
import os
from unittest import mock
import requests
from requests.models import Response
from skyflow.vault._query import getQueryRequestBody, getQueryResponse
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
from skyflow.vault._client import Client
from skyflow.vault._config import Configuration, QueryOptions

class TestQuery(unittest.TestCase):

def setUp(self) -> None:
self.dataPath = os.path.join(os.getcwd(), 'tests/vault/data/')
query = "SELECT * FROM pii_fields WHERE skyflow_id='3ea3861-x107-40w8-la98-106sp08ea83f'"
self.data = {"query": query}
self.mockRequest = {"records": [query]}

self.mockResponse = {
"records": [
{
"fields": {
"card_number": "XXXXXXXXXXXX1111",
"card_pin": "*REDACTED*",
"cvv": "",
"expiration_date": "*REDACTED*",
"expiration_month": "*REDACTED*",
"expiration_year": "*REDACTED*",
"name": "a***te",
"skyflow_id": "3ea3861-x107-40w8-la98-106sp08ea83f",
"ssn": "XXX-XX-6789",
"zip_code": None
},
"tokens": None
}
]
}

self.requestId = '5d5d7e21-c789-9fcc-ba31-2a279d3a28ef'

self.mockApiError = {
"error": {
"grpc_code": 13,
"http_code": 500,
"message": "ERROR (internal_error): Could not find Notebook Mapping Notebook Name was not found",
"http_status": "Internal Server Error",
"details": []
}
}

self.mockFailResponse = {
"error": {
"code": 500,
"description": "ERROR (internal_error): Could not find Notebook Mapping Notebook Name was not found - request id: 5d5d7e21-c789-9fcc-ba31-2a279d3a28ef"
}
}

self.queryOptions = QueryOptions()

return super().setUp()

def getDataPath(self, file):
return self.dataPath + file + '.json'

def testGetQueryRequestBodyWithValidBody(self):
body = json.loads(getQueryRequestBody(self.data, self.queryOptions))
expectedOutput = {
"query": "SELECT * FROM pii_fields WHERE skyflow_id='3ea3861-x107-40w8-la98-106sp08ea83f'",
}
self.assertEqual(body, expectedOutput)

def testGetQueryRequestBodyNoQuery(self):
invalidData = {"invalidKey": self.data["query"]}
try:
getQueryRequestBody(invalidData, self.queryOptions)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
self.assertEqual(
e.message, SkyflowErrorMessages.QUERY_KEY_ERROR.value)

def testGetQueryRequestBodyInvalidType(self):
invalidData = {"query": ['SELECT * FROM table_name']}
try:
getQueryRequestBody(invalidData, self.queryOptions)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
self.assertEqual(
e.message, SkyflowErrorMessages.INVALID_QUERY_TYPE.value % (str(type(invalidData["query"]))))

def testGetQueryRequestBodyEmptyBody(self):
invalidData = {"query": ''}
try:
getQueryRequestBody(invalidData, self.queryOptions)
self.fail('Should have thrown an error')
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
self.assertEqual(
e.message, SkyflowErrorMessages.EMPTY_QUERY.value)

def testGetQueryValidResponse(self):
response = Response()
response.status_code = 200
response._content = b'{"key": "value"}'
try:
responseDict = getQueryResponse(response)
self.assertDictEqual(responseDict, {'key': 'value'})
except SkyflowError as e:
self.fail()

def testClientInit(self):
config = Configuration(
'vaultid', 'https://skyflow.com', lambda: 'test')
client = Client(config)
self.assertEqual(client.vaultURL, 'https://skyflow.com')
self.assertEqual(client.vaultID, 'vaultid')
self.assertEqual(client.tokenProvider(), 'test')

def testGetQueryResponseSuccessInvalidJson(self):
invalid_response = Response()
invalid_response.status_code = 200
invalid_response._content = b'invalid-json'
try:
getQueryResponse(invalid_response)
self.fail('not failing on invalid json')
except SkyflowError as se:
self.assertEqual(se.code, 200)
self.assertEqual(
se.message, SkyflowErrorMessages.RESPONSE_NOT_JSON.value % 'invalid-json')

def testGetQueryResponseFailInvalidJson(self):
invalid_response = mock.Mock(
spec=requests.Response,
status_code=404,
content=b'error'
)
invalid_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Not found")
try:
getQueryResponse(invalid_response)
self.fail('Not failing on invalid error json')
except SkyflowError as se:
self.assertEqual(se.code, 404)
self.assertEqual(
se.message, SkyflowErrorMessages.RESPONSE_NOT_JSON.value % 'error')

def testGetQueryResponseFail(self):
response = mock.Mock(
spec=requests.Response,
status_code=500,
content=json.dumps(self.mockApiError).encode('utf-8')
)
response.headers = {"x-request-id": self.requestId}
response.raise_for_status.side_effect = requests.exceptions.HTTPError("Server Error")
try:
getQueryResponse(response)
self.fail('not throwing exception when error code is 500')
except SkyflowError as e:
self.assertEqual(e.code, 500)
self.assertEqual(e.message, SkyflowErrorMessages.SERVER_ERROR.value)
self.assertDictEqual(e.data, self.mockFailResponse)

def testQueryInvalidToken(self):
config = Configuration('id', 'url', lambda: 'invalid-token')
try:
Client(config).query({'query': 'SELECT * FROM table_name'})
self.fail()
except SkyflowError as e:
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
self.assertEqual(
e.message, SkyflowErrorMessages.TOKEN_PROVIDER_INVALID_TOKEN.value)