Skip to content

Commit ec98d8f

Browse files
SK-2504: Add support for custom tokenUri (#228)
* SK-2504: add support for custom token uri
1 parent 49f6b8b commit ec98d8f

File tree

11 files changed

+454
-51
lines changed

11 files changed

+454
-51
lines changed

skyflow/service_account/_utils.py

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
import datetime
33
import time
44
import jwt
5+
from urllib.parse import urlparse
56
from skyflow.error import SkyflowError
67
from skyflow.service_account.client.auth_client import AuthClient
78
from skyflow.utils.logger import log_info, log_error_log
89
from skyflow.utils import get_base_url, format_scope, SkyflowMessages
10+
from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError
11+
from skyflow.utils import is_valid_url
912

1013

1114
invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value
@@ -78,7 +81,14 @@ def get_service_account_token(credentials, options, logger):
7881
except:
7982
log_error_log(SkyflowMessages.ErrorLogs.TOKEN_URI_IS_REQUIRED.value, logger=logger)
8083
raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code)
81-
84+
85+
if not isinstance(token_uri, str) or not is_valid_url(token_uri):
86+
log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger=logger)
87+
raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code)
88+
89+
if options and "token_uri" in options:
90+
token_uri = options["token_uri"]
91+
8292
signed_token = get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger)
8393
base_url = get_base_url(token_uri)
8494
auth_client = AuthClient(base_url)
@@ -88,10 +98,17 @@ def get_service_account_token(credentials, options, logger):
8898
if options and "role_ids" in options:
8999
formatted_scope = format_scope(options.get("role_ids"))
90100

91-
response = auth_api.authentication_service_get_auth_token(assertion = signed_token,
101+
try:
102+
response = auth_api.authentication_service_get_auth_token(assertion = signed_token,
92103
grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer",
93104
scope=formatted_scope)
94-
log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger)
105+
log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger)
106+
except UnauthorizedError:
107+
log_error_log(SkyflowMessages.ErrorLogs.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value, logger=logger)
108+
raise SkyflowError(SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value, invalid_input_error_code)
109+
except Exception:
110+
log_error_log(SkyflowMessages.ErrorLogs.FAILED_TO_GET_BEARER_TOKEN.value, logger=logger)
111+
raise SkyflowError(SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value, invalid_input_error_code)
95112
return response.access_token, response.token_type
96113

97114
def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger):
@@ -112,32 +129,41 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger):
112129

113130

114131
def get_signed_tokens(credentials_obj, options):
115-
try:
116-
expiry_time = int(time.time()) + options.get("time_to_live", 60)
117-
prefix = "signed_token_"
118-
119-
if options and options.get("data_tokens"):
120-
for token in options["data_tokens"]:
121-
claims = {
122-
"iss": "sdk",
123-
"key": credentials_obj.get("keyID"),
124-
"exp": expiry_time,
125-
"sub": credentials_obj.get("clientID"),
126-
"tok": token,
127-
"iat": int(time.time()),
128-
}
129-
130-
if "ctx" in options:
131-
claims["ctx"] = options["ctx"]
132-
133-
private_key = credentials_obj.get("privateKey")
132+
expiry_time = int(time.time()) + options.get("time_to_live", 60)
133+
prefix = "signed_token_"
134+
135+
token_uri = credentials_obj.get("tokenURI")
136+
if not isinstance(token_uri, str) or not is_valid_url(token_uri):
137+
log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value)
138+
raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code)
139+
140+
if options and "token_uri" in options:
141+
token_uri = options["token_uri"]
142+
143+
144+
if options and options.get("data_tokens"):
145+
for token in options["data_tokens"]:
146+
claims = {
147+
"iss": "sdk",
148+
"key": credentials_obj.get("keyID"),
149+
"exp": expiry_time,
150+
"sub": credentials_obj.get("clientID"),
151+
"tok": token,
152+
"iat": int(time.time()),
153+
}
154+
155+
if "ctx" in options:
156+
claims["ctx"] = options["ctx"]
157+
158+
private_key = credentials_obj.get("privateKey")
159+
try:
134160
signed_jwt = jwt.encode(claims, private_key, algorithm="RS256")
135-
response_object = get_signed_data_token_response_object(prefix + signed_jwt, token)
136-
log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value)
137-
return response_object
161+
except Exception:
162+
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code)
138163

139-
except Exception:
140-
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code)
164+
response_object = get_signed_data_token_response_object(prefix + signed_jwt, token)
165+
log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value)
166+
return response_object
141167

142168

143169
def generate_signed_data_tokens(credentials_file_path, options):

skyflow/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ..utils.enums import LogLevel, Env, TokenType
22
from ._skyflow_messages import SkyflowMessages
33
from ._version import SDK_VERSION
4-
from ._helpers import get_base_url, format_scope
4+
from ._helpers import get_base_url, format_scope, is_valid_url
55
from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key, encode_column_values, parse_deidentify_text_response, parse_reidentify_text_response, convert_detected_entity_to_entity_info

skyflow/utils/_helpers.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,11 @@ def get_base_url(url):
88
def format_scope(scopes):
99
if not scopes:
1010
return None
11-
return " ".join([f"role:{scope}" for scope in scopes])
11+
return " ".join([f"role:{scope}" for scope in scopes])
12+
13+
def is_valid_url(url):
14+
try:
15+
result = urlparse(url)
16+
return all([result.scheme in ("http", "https"), result.netloc])
17+
except Exception:
18+
return False

skyflow/utils/_skyflow_messages.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,13 @@ class Error(Enum):
153153
MISSING_CLIENT_ID = f"{error_prefix} Initialization failed. Unable to read client ID in credentials. Verify your client ID."
154154
MISSING_KEY_ID = f"{error_prefix} Initialization failed. Unable to read key ID in credentials. Verify your key ID."
155155
MISSING_TOKEN_URI = f"{error_prefix} Initialization failed. Unable to read token URI in credentials. Verify your token URI."
156+
INVALID_TOKEN_URI = f"{error_prefix} Initialization failed. Invalid Skyflow credentials. The token URI must be a string and a valid URL."
156157
JWT_INVALID_FORMAT = f"{error_prefix} Initialization failed. Invalid private key format. Verify your credentials."
157158
JWT_DECODE_ERROR = f"{error_prefix} Validation error. Invalid access token. Verify your credentials."
158159
FILE_INVALID_JSON = f"{error_prefix} Initialization failed. File at {{}} is not in valid JSON format. Verify the file contents."
159160
INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV = f"{error_prefix} Validation error. Invalid JSON format in SKYFLOW_CREDENTIALS environment variable."
161+
FAILED_TO_GET_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Failed to generate bearer token."
162+
UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Authorization failed while retrieving the bearer token."
160163

161164
INVALID_TEXT_IN_DEIDENTIFY= f"{error_prefix} Validation error. The text field is required and must be a non-empty string. Specify a valid text."
162165
INVALID_ENTITIES_IN_DEIDENTIFY= f"{error_prefix} Validation error. The entities field must be an array of DetectEntities enums. Specify a valid entities."
@@ -332,6 +335,8 @@ class ErrorLogs(Enum):
332335
KEY_ID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Key ID is required."
333336
TOKEN_URI_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Token URI is required."
334337
INVALID_TOKEN_URI = f"{ERROR}: [{error_prefix}] Invalid value for token URI in credentials."
338+
FAILED_TO_GET_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Failed to generate bearer token."
339+
UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Authorization failed while retrieving the bearer token."
335340

336341

337342
TABLE_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Table is required."

skyflow/utils/validations/_validations.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \
1111
GetDetectRunRequest, Bleep, DeidentifyFileRequest
1212
from skyflow.vault.detect._file_input import FileInput
13+
from skyflow.utils._helpers import is_valid_url
1314

1415
valid_vault_config_keys = ["vault_id", "cluster_id", "credentials", "env"]
1516
valid_connection_config_keys = ["connection_id", "connection_url", "credentials"]
@@ -138,6 +139,15 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non
138139
raise SkyflowError(SkyflowMessages.Error.INVALID_API_KEY.value.format(config_id_type, config_id)
139140
if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value,
140141
invalid_input_error_code)
142+
143+
if "token_uri" in credentials:
144+
token_uri = credentials.get("token_uri")
145+
if (
146+
token_uri is None
147+
or not isinstance(token_uri, str)
148+
or not is_valid_url(token_uri)
149+
):
150+
raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code)
141151

142152
def validate_log_level(logger, log_level):
143153
if not isinstance(log_level, LogLevel):
@@ -202,10 +212,8 @@ def validate_update_vault_config(logger, config):
202212
if "env" in config and config.get("env") not in Env:
203213
raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code)
204214

205-
if "credentials" not in config:
206-
raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", vault_id), invalid_input_error_code)
207-
208-
validate_credentials(logger, config.get("credentials"), "vault", vault_id)
215+
if "credentials" in config and config.get("credentials"):
216+
validate_credentials(logger, config.get("credentials"), "vault", vault_id)
209217

210218
return True
211219

@@ -413,9 +421,6 @@ def validate_insert_request(logger, request):
413421
if key is None or key == "":
414422
log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format("INSERT"), logger = logger)
415423

416-
if value is None or value == "":
417-
log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_VALUES.value.format("INSERT", key), logger = logger)
418-
419424
if request.upsert is not None and (not isinstance(request.upsert, str) or not request.upsert.strip()):
420425
log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value("INSERT"), logger = logger)
421426
raise SkyflowError(SkyflowMessages.Error.INVALID_UPSERT_OPTIONS_TYPE.value, invalid_input_error_code)

skyflow/vault/client/client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from skyflow.error import SkyflowError
12
from skyflow.generated.rest.client import Skyflow
23
from skyflow.service_account import generate_bearer_token, generate_bearer_token_from_creds, is_expired
34
from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages
@@ -62,6 +63,8 @@ def get_bearer_token(self, credentials):
6263
"role_ids": self.__config.get("roles"),
6364
"ctx": self.__config.get("ctx")
6465
}
66+
if "token_uri" in credentials and credentials.get("token_uri"):
67+
options["token_uri"] = credentials.get("token_uri")
6568

6669
if self.__bearer_token is None or self.__is_config_updated:
6770
if 'path' in credentials:
@@ -85,7 +88,7 @@ def get_bearer_token(self, credentials):
8588

8689
if is_expired(self.__bearer_token):
8790
self.__is_config_updated = True
88-
raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)
91+
raise SkyflowError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)
8992

9093
return self.__bearer_token
9194

tests/service_account/test__utils.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,4 +143,73 @@ def test_generate_signed_data_tokens_from_creds_with_invalid_string(self):
143143
credentials_string = '{'
144144
with self.assertRaises(SkyflowError) as context:
145145
result = generate_signed_data_tokens_from_creds(credentials_string, options)
146-
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value)
146+
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value)
147+
148+
@patch("skyflow.service_account._utils.AuthClient")
149+
@patch("skyflow.service_account._utils.get_signed_jwt")
150+
def test_get_service_account_token_with_role_ids_formats_scope(self, mock_get_signed_jwt, mock_auth_client):
151+
creds = {
152+
"privateKey": "private_key",
153+
"clientID": "client_id",
154+
"keyID": "key_id",
155+
"tokenURI": "https://valid-url.com"
156+
}
157+
options = {"role_ids": ["role1", "role2"]}
158+
mock_get_signed_jwt.return_value = "signed"
159+
mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value
160+
mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), {"access_token": "token",
161+
"token_type": "bearer"})
162+
access_token, token_type = get_service_account_token(creds, options, None)
163+
self.assertEqual(access_token, "token")
164+
self.assertEqual(token_type, "bearer")
165+
args, kwargs = mock_auth_api.authentication_service_get_auth_token.call_args
166+
self.assertIn("scope", kwargs)
167+
self.assertEqual(kwargs["scope"], "role:role1 role:role2")
168+
169+
@patch("skyflow.service_account._utils.AuthClient")
170+
@patch("skyflow.service_account._utils.get_signed_jwt")
171+
def test_get_service_account_token_unauthorized_error(self, mock_get_signed_jwt, mock_auth_client):
172+
creds = {
173+
"privateKey": "private_key",
174+
"clientID": "client_id",
175+
"keyID": "key_id",
176+
"tokenURI": "https://valid-url.com"
177+
}
178+
mock_get_signed_jwt.return_value = "signed"
179+
mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value
180+
from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError
181+
mock_auth_api.authentication_service_get_auth_token.side_effect = UnauthorizedError("unauthorized")
182+
with self.assertRaises(SkyflowError) as context:
183+
get_service_account_token(creds, {}, None)
184+
self.assertEqual(context.exception.message,
185+
SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value)
186+
187+
@patch("skyflow.service_account._utils.AuthClient")
188+
@patch("skyflow.service_account._utils.get_signed_jwt")
189+
def test_get_service_account_token_generic_exception(self, mock_get_signed_jwt, mock_auth_client):
190+
creds = {
191+
"privateKey": "private_key",
192+
"clientID": "client_id",
193+
"keyID": "key_id",
194+
"tokenURI": "https://valid-url.com"
195+
}
196+
mock_get_signed_jwt.return_value = "signed"
197+
mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value
198+
mock_auth_api.authentication_service_get_auth_token.side_effect = Exception("some error")
199+
with self.assertRaises(SkyflowError) as context:
200+
get_service_account_token(creds, {}, None)
201+
self.assertEqual(context.exception.message, SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value)
202+
203+
@patch("jwt.encode", side_effect=Exception("jwt error"))
204+
def test_get_signed_tokens_jwt_encode_exception(self, mock_jwt_encode):
205+
creds = {
206+
"privateKey": "private_key",
207+
"clientID": "client_id",
208+
"keyID": "key_id",
209+
"tokenURI": "https://valid-url.com"
210+
}
211+
options = {"data_tokens": ["token1"]}
212+
with self.assertRaises(SkyflowError) as context:
213+
from skyflow.service_account._utils import get_signed_tokens
214+
get_signed_tokens(creds, options)
215+
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value)

tests/utils/test__helpers.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import unittest
2-
from skyflow.utils import get_base_url, format_scope
2+
from skyflow.utils import get_base_url, format_scope, is_valid_url
33

44
VALID_URL = "https://example.com/path?query=1"
55
BASE_URL = "https://example.com"
@@ -35,4 +35,27 @@ def test_format_scope_single_scope(self):
3535
def test_format_scope_special_characters(self):
3636
scopes_with_special_chars = ["admin", "user:write", "read-only"]
3737
expected_result = "role:admin role:user:write role:read-only"
38-
self.assertEqual(format_scope(scopes_with_special_chars), expected_result)
38+
self.assertEqual(format_scope(scopes_with_special_chars), expected_result)
39+
40+
def test_is_valid_url_valid(self):
41+
self.assertTrue(is_valid_url("https://example.com"))
42+
self.assertTrue(is_valid_url("http://example.com/path"))
43+
44+
def test_is_valid_url_invalid(self):
45+
self.assertFalse(is_valid_url("ftp://example.com"))
46+
self.assertFalse(is_valid_url("example.com"))
47+
self.assertFalse(is_valid_url("invalid-url"))
48+
self.assertFalse(is_valid_url(""))
49+
50+
def test_is_valid_url_none(self):
51+
self.assertFalse(is_valid_url(None))
52+
53+
def test_is_valid_url_no_scheme(self):
54+
self.assertFalse(is_valid_url("www.example.com"))
55+
56+
def test_is_valid_url_exception(self):
57+
class BadStr:
58+
def __str__(self):
59+
raise Exception("bad str")
60+
61+
self.assertFalse(is_valid_url(BadStr()))

0 commit comments

Comments
 (0)