Skip to content

support static credentials #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,7 @@
"enum-compat>=0.0.1",
),
options={"bdist_wheel": {"universal": True}},
extras_require={
"yc": ["yandexcloud", ],
}
)
30 changes: 30 additions & 0 deletions ydb/_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
interceptor = None


_grpcs_protocol = "grpcs://"
_grpc_protocol = "grpc://"


def wrap_result_in_future(result):
f = futures.Future()
f.set_result(result)
Expand All @@ -33,6 +37,32 @@ def x_ydb_sdk_build_info_header():
return ("x-ydb-sdk-build-info", "ydb-python-sdk/" + ydb_version.VERSION)


def is_secure_protocol(endpoint):
return endpoint.startswith("grpcs://")


def wrap_endpoint(endpoint):
if endpoint.startswith(_grpcs_protocol):
return endpoint[len(_grpcs_protocol) :]
if endpoint.startswith(_grpc_protocol):
return endpoint[len(_grpc_protocol) :]
return endpoint


def parse_connection_string(connection_string):
cs = connection_string
if not cs.startswith(_grpc_protocol) and not cs.startswith(_grpcs_protocol):
# default is grpcs
cs = _grpcs_protocol + cs

p = six.moves.urllib.parse.urlparse(connection_string)
b = six.moves.urllib.parse.parse_qs(p.query)
database = b.get("database", [])
assert len(database) > 0

return p.scheme + "://" + p.netloc, database[0]


# Decorator that ensures no exceptions are leaked from decorated async call
def wrap_async_call_exceptions(f):
@functools.wraps(f)
Expand Down
133 changes: 12 additions & 121 deletions ydb/aio/iam.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
import time

import abc
import asyncio
import logging
import six
from ydb import issues, credentials
from ydb.iam import auth
from .credentials import AbstractExpiringTokenCredentials

logger = logging.getLogger(__name__)

Expand All @@ -25,127 +24,20 @@
aiohttp = None


class _OneToManyValue(object):
def __init__(self):
self._value = None
self._condition = asyncio.Condition()

async def consume(self, timeout=3):
async with self._condition:
if self._value is None:
try:
await asyncio.wait_for(self._condition.wait(), timeout=timeout)
except Exception:
return self._value
return self._value

async def update(self, n_value):
async with self._condition:
prev_value = self._value
self._value = n_value
if prev_value is None:
self._condition.notify_all()


class _AtMostOneExecution(object):
def __init__(self):
self._can_schedule = True
self._lock = asyncio.Lock() # Lock to guarantee only one execution

async def _wrapped_execution(self, callback):
await self._lock.acquire()
try:
res = callback()
if asyncio.iscoroutine(res):
await res
except Exception:
pass

finally:
self._lock.release()
self._can_schedule = True

def submit(self, callback):
if self._can_schedule:
self._can_schedule = False
asyncio.ensure_future(self._wrapped_execution(callback))


@six.add_metaclass(abc.ABCMeta)
class IamTokenCredentials(auth.IamTokenCredentials):
def __init__(self):
super(IamTokenCredentials, self).__init__()
self._tp = _AtMostOneExecution()
self._iam_token = _OneToManyValue()

@abc.abstractmethod
async def _get_iam_token(self):
pass

async def _refresh(self):
current_time = time.time()
self._log_refresh_start(current_time)

try:
auth_metadata = await self._get_iam_token()
await self._iam_token.update(auth_metadata["access_token"])
self.update_expiration_info(auth_metadata)
self.logger.info(
"Token refresh successful. current_time %s, refresh_in %s",
current_time,
self._refresh_in,
)

except (KeyboardInterrupt, SystemExit):
return

except Exception as e:
self.last_error = str(e)
await asyncio.sleep(1)
self._tp.submit(self._refresh)

async def iam_token(self):
current_time = time.time()
if current_time > self._refresh_in:
self._tp.submit(self._refresh)

iam_token = await self._iam_token.consume(timeout=3)
if iam_token is None:
if self.last_error is None:
raise issues.ConnectionError(
"%s: timeout occurred while waiting for token.\n%s"
% self.__class__.__name__,
self.extra_error_message,
)
raise issues.ConnectionError(
"%s: %s.\n%s"
% (self.__class__.__name__, self.last_error, self.extra_error_message)
)
return iam_token

async def auth_metadata(self):
return [(credentials.YDB_AUTH_TICKET_HEADER, await self.iam_token())]


@six.add_metaclass(abc.ABCMeta)
class TokenServiceCredentials(IamTokenCredentials):
class TokenServiceCredentials(AbstractExpiringTokenCredentials):
def __init__(self, iam_endpoint=None, iam_channel_credentials=None):
super(TokenServiceCredentials, self).__init__()
assert (
iam_token_service_pb2_grpc is not None
), "run pip install==ydb[yc] to use service account credentials"
self._get_token_request_timeout = 10
self._iam_endpoint = (
"iam.api.cloud.yandex.net:443" if iam_endpoint is None else iam_endpoint
)
self._iam_channel_credentials = (
{} if iam_channel_credentials is None else iam_channel_credentials
)
self._get_token_request_timeout = 10
if (
iam_token_service_pb2_grpc is None
or jwt is None
or iam_token_service_pb2 is None
):
raise RuntimeError(
"Install jwt & yandex python cloud library to use service account credentials provider"
)

def _channel_factory(self):
return grpc.aio.secure_channel(
Expand All @@ -157,7 +49,7 @@ def _channel_factory(self):
def _get_token_request(self):
pass

async def _get_iam_token(self):
async def _make_token_request(self):
async with self._channel_factory() as channel:
stub = iam_token_service_pb2_grpc.IamTokenServiceStub(channel)
response = await stub.Create(
Expand Down Expand Up @@ -209,20 +101,19 @@ def _get_token_request(self):
)


class MetadataUrlCredentials(IamTokenCredentials):
class MetadataUrlCredentials(AbstractExpiringTokenCredentials):
def __init__(self, metadata_url=None):
super(MetadataUrlCredentials, self).__init__()
if aiohttp is None:
raise RuntimeError(
"Install aiohttp library to use metadata credentials provider"
)
assert (
aiohttp is not None
), "Install aiohttp library to use metadata credentials provider"
self._metadata_url = (
auth.DEFAULT_METADATA_URL if metadata_url is None else metadata_url
)
self._tp.submit(self._refresh)
self.extra_error_message = "Check that metadata service configured properly and application deployed in VM or function at Yandex.Cloud."

async def _get_iam_token(self):
async def _make_token_request(self):
timeout = aiohttp.ClientTimeout(total=2)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(
Expand Down
3 changes: 2 additions & 1 deletion ydb/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def _construct_metadata(driver_config, settings):
if driver_config.database is not None:
metadata.append((YDB_DATABASE_HEADER, driver_config.database))

if driver_config.credentials is not None:
need_rpc_auth = getattr(settings, "need_rpc_auth", True)
if driver_config.credentials is not None and need_rpc_auth:
metadata.extend(driver_config.credentials.auth_metadata())

if settings is not None:
Expand Down
Loading