Skip to content

Commit

Permalink
SDK - Components - Added ComponentStore search (#3884)
Browse files Browse the repository at this point in the history
* SDK - Components - Added ComponentStore search

ComponentStore(...).search searches for components by name in the configured component store. It prints the name and URL for components that match the given name.
Only components on GitHub are currently supported.

Example:

```
kfp.components.ComponentStore.default_store.search('xgboost')

>>> Xgboost train   https://raw.githubusercontent.com/.../components/XGBoost/Train/component.yaml
>>> Xgboost predict https://raw.githubusercontent.com/.../components/XGBoost/Predict/component.yaml
```

* Implemented the review feedback

* Added retries
  • Loading branch information
Ark-kun committed Jun 4, 2020
1 parent 04e23d2 commit 1403b9b
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 1 deletion.
141 changes: 140 additions & 1 deletion sdk/python/kfp/components/_component_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,19 @@

from pathlib import Path
import copy
import hashlib
import json
import logging
import requests
from typing import Callable
import tempfile
from typing import Callable, Iterable
from . import _components as comp
from .structures import ComponentReference
from ._key_value_store import KeyValueStore


_COMPONENT_FILENAME = 'component.yaml'


class ComponentStore:
def __init__(self, local_search_paths=None, url_search_prefixes=None):
Expand All @@ -18,6 +27,10 @@ def __init__(self, local_search_paths=None, url_search_prefixes=None):
self._digests_subpath = 'versions/sha256'
self._tags_subpath = 'versions/tags'

cache_base_dir = Path(tempfile.gettempdir()) / '.kfp_components'
self._git_blob_hash_to_data_db = KeyValueStore(cache_dir=cache_base_dir / 'git_blob_hash_to_data')
self._url_to_info_db = KeyValueStore(cache_dir=cache_base_dir / 'url_to_info')

def load_component_from_url(self, url):
    '''Loads the component located at ``url``.

    Thin convenience wrapper: the work is delegated to
    ``comp.load_component_from_url`` and its result is returned unchanged.
    '''
    loaded_component = comp.load_component_from_url(url)
    return loaded_component

Expand Down Expand Up @@ -129,6 +142,132 @@ def _load_component_from_ref(self, component_ref: ComponentReference) -> Callabl
component_ref = self._load_component_spec_in_component_ref(component_ref)
return comp._create_task_factory_from_component_spec(component_spec=component_ref.spec, component_ref=component_ref)

def search(self, name: str):
    '''Prints the components whose name contains ``name`` (case-insensitive).

    The component cache is refreshed first, then every cached entry is
    checked. Each match is printed on its own line as
    ``<component name><TAB><component URL>``.
    Only components on GitHub are currently supported.

    Example::

        kfp.components.ComponentStore.default_store.search('xgboost')

        >>> Xgboost train   https://raw.githubusercontent.com/.../components/XGBoost/Train/component.yaml
        >>> Xgboost predict https://raw.githubusercontent.com/.../components/XGBoost/Predict/component.yaml
    '''
    self._refresh_component_cache()
    info_db = self._url_to_info_db
    for component_url in info_db.keys():
        # Each cache entry is a JSON document describing one component.
        cached_info = json.loads(info_db.try_get_value_bytes(component_url))
        found_name = cached_info['name']
        if name.casefold() not in found_name.casefold():
            continue
        print(found_name, component_url, sep='\t')

def list(self):
    '''Prints every known component.

    Implemented as a search for the empty string, which is contained in
    every component name.
    '''
    match_everything = ''
    self.search(match_everything)

def _refresh_component_cache(self):
    '''Adds newly discovered GitHub components to the local caches.

    For every URL search prefix that points to raw.githubusercontent.com,
    the candidate component files of that repository are listed. Candidates
    whose URL is already in ``self._url_to_info_db`` are skipped. For the
    others the component text is taken from ``self._git_blob_hash_to_data_db``
    (keyed by git blob hash) or downloaded, verified and stored there.
    Finally a JSON record (name, url, git_blob_hash, digest) is stored in
    ``self._url_to_info_db``.

    Downloaded files that cannot be loaded as a component are skipped.

    Raises:
        RuntimeError: The downloaded data does not match the git blob hash
            reported for the candidate.
        requests.HTTPError: A download returned an error status
            (raised by ``raise_for_status``).
    '''
    for url_search_prefix in self.url_search_prefixes:
        if not url_search_prefix.startswith('https://raw.githubusercontent.com/'):
            # Only GitHub-hosted component repositories can be enumerated.
            continue

        logging.info('Searching for components in "{}"'.format(url_search_prefix))
        for candidate in _list_candidate_component_uris_from_github_repo(url_search_prefix):
            component_url = candidate['url']
            if self._url_to_info_db.exists(component_url):
                continue

            logging.debug('Found new component URL: "{}"'.format(component_url))

            blob_hash = candidate['git_blob_hash']
            if not self._git_blob_hash_to_data_db.exists(blob_hash):
                logging.debug('Downloading component spec from "{}"'.format(component_url))
                response = _get_request_session().get(component_url)
                response.raise_for_status()
                component_data = response.content

                # Verifying the hash
                received_data_hash = _calculate_git_blob_hash(component_data)
                if received_data_hash.lower() != blob_hash.lower():
                    raise RuntimeError(
                        'The downloaded component ({}) has incorrect hash: "{}" != "{}"'.format(
                            component_url, received_data_hash, blob_hash,
                        )
                    )

                # Verifying that the component is loadable.
                # Files that are not valid components are skipped (best effort),
                # but KeyboardInterrupt/SystemExit must still propagate, so only
                # Exception subclasses are caught.
                try:
                    component_spec = comp._load_component_spec_from_component_text(component_data)
                except Exception:
                    logging.debug(
                        'Skipping "{}": could not load it as a component'.format(component_url),
                        exc_info=True,
                    )
                    continue
                self._git_blob_hash_to_data_db.store_value_bytes(blob_hash, component_data)
            else:
                component_data = self._git_blob_hash_to_data_db.try_get_value_bytes(blob_hash)
                component_spec = comp._load_component_spec_from_component_text(component_data)

            component_name = component_spec.name
            self._url_to_info_db.store_value_text(component_url, json.dumps(dict(
                name=component_name,
                url=component_url,
                git_blob_hash=blob_hash,
                digest=_calculate_component_digest(component_data),
            )))


def _get_request_session(max_retries: int = 3):
    '''Creates a new ``requests.Session`` that retries failed requests.

    GET and POST requests over http:// and https:// are retried when the
    server answers with status 413, 429, 500, 502, 503 or 504.

    Args:
        max_retries: Total number of retries allowed per request.

    Returns:
        A newly created ``requests.Session`` with retrying adapters mounted.
    '''
    session = requests.Session()

    retry_class = requests.packages.urllib3.util.retry.Retry
    retry_options = dict(
        total=max_retries,
        backoff_factor=0.1,
        status_forcelist=[413, 429, 500, 502, 503, 504],
    )
    retried_methods = frozenset(['GET', 'POST'])
    try:
        # urllib3 >= 1.26 names this option "allowed_methods";
        # "method_whitelist" was deprecated in 1.26 and removed in 2.0.
        retry_strategy = retry_class(allowed_methods=retried_methods, **retry_options)
    except TypeError:
        # Older urllib3 releases only know the original option name.
        retry_strategy = retry_class(method_whitelist=retried_methods, **retry_options)

    session.mount('https://', requests.adapters.HTTPAdapter(max_retries=retry_strategy))
    session.mount('http://', requests.adapters.HTTPAdapter(max_retries=retry_strategy))

    return session


def _calculate_git_blob_hash(data: bytes) -> str:
return hashlib.sha1(b'blob ' + str(len(data)).encode('utf-8') + b'\x00' + data).hexdigest()


def _calculate_component_digest(data: bytes) -> str:
return hashlib.sha256(data.replace(b'\r\n', b'\n')).hexdigest()


def _list_candidate_component_uris_from_github_repo(url_search_prefix: str) -> Iterable[dict]:
    '''Yields candidate component files found in a GitHub repository.

    The repository is searched through the GitHub code search API for files
    named like ``_COMPONENT_FILENAME``. Result pages are requested until an
    empty page is returned.

    Args:
        url_search_prefix: URL of the form
            ``https://raw.githubusercontent.com/<org>/<repo>/<ref>/<path prefix>``.
            Only the organization and repository parts are used.

    Yields:
        Dictionaries with the keys ``url`` (direct content URL), ``path``
        and ``git_blob_hash``.

    Raises:
        ValueError: The prefix has fewer than seven "/"-separated parts.
        requests.HTTPError: The search API returned an error status
            (raised by ``raise_for_status``).
    '''
    # Parts: scheme, '', host, org, repo, ref, path prefix.
    (_, _, _, org, repo, _, _) = url_search_prefix.split('/', 6)
    # One session serves all result pages, so connections can be reused.
    session = _get_request_session()
    for page in range(1, 999):
        # 100 is the largest page size the GitHub API documents.
        search_url = (
            'https://api.github.com/search/code?q=filename:{}+repo:{}/{}&page={}&per_page=100'
        ).format(_COMPONENT_FILENAME, org, repo, page)
        response = session.get(search_url)
        response.raise_for_status()
        result = response.json()
        items = result['items']
        if not items:
            break
        for item in items:
            html_url = item['html_url']
            # Constructing direct content URL
            # There is an API (/repos/:owner/:repo/git/blobs/:file_sha) for
            # getting the blob content, but it requires decoding the content.
            raw_url = html_url.replace(
                'https://github.com/', 'https://raw.githubusercontent.com/'
            ).replace('/blob/', '/', 1)
            if not raw_url.endswith(_COMPONENT_FILENAME):
                # GitHub matches component_test.yaml when searching for filename:"component.yaml"
                continue
            yield dict(
                url=raw_url,
                path=item['path'],
                git_blob_hash=item['sha'],
            )


ComponentStore.default_store = ComponentStore(
local_search_paths=[
Expand Down
62 changes: 62 additions & 0 deletions sdk/python/kfp/components/_key_value_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import hashlib
from pathlib import Path


class KeyValueStore:
    '''Simple file-based key-value store.

    Every entry is kept in two files inside ``cache_dir``, both named after
    the SHA-256 hex digest of the UTF-8 encoded key:
    ``<digest>.key`` holds the key text and ``<digest>.value`` holds the
    value bytes.
    '''
    KEY_FILE_SUFFIX = '.key'
    VALUE_FILE_SUFFIX = '.value'

    def __init__(
        self,
        cache_dir: str,
    ):
        '''Creates a store backed by the ``cache_dir`` directory.

        Args:
            cache_dir: Directory for the entry files (string or ``Path``).
                It is created on the first write, not here.
        '''
        cache_dir = Path(cache_dir)
        hash_func = (lambda text: hashlib.sha256(text.encode('utf-8')).hexdigest())
        self.cache_dir = cache_dir
        self.hash_func = hash_func

    def store_value_text(self, key: str, text: str) -> str:
        '''Stores ``text`` UTF-8 encoded under ``key``. Returns the cache ID.'''
        return self.store_value_bytes(key, text.encode('utf-8'))

    def store_value_bytes(self, key: str, data: bytes) -> str:
        '''Stores ``data`` under ``key``, overwriting any existing value.

        Returns:
            The cache ID (hash of the key) used for the entry file names.

        Raises:
            RuntimeError: The existing key file for this cache ID contains
                a different key.
        '''
        cache_id = self.hash_func(key)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        cache_key_file_path = self.cache_dir / (cache_id + KeyValueStore.KEY_FILE_SUFFIX)
        cache_value_file_path = self.cache_dir / (cache_id + KeyValueStore.VALUE_FILE_SUFFIX)
        if cache_key_file_path.exists():
            old_key = cache_key_file_path.read_text()
            if key != old_key:
                raise RuntimeError(
                    'Cache is corrupted: File "{}" contains existing key '
                    '"{}" != new key "{}"'.format(cache_key_file_path, old_key, key)
                )
        if cache_value_file_path.exists():
            # The existing value must be *read* for the comparison. Calling
            # write_bytes() without data raised TypeError on every overwrite.
            old_data = cache_value_file_path.read_bytes()
            if data != old_data:
                # TODO: Add options to raise error when overwriting the value.
                pass
        # The value is written before the key file, which is the file that
        # exists() and keys() look at.
        cache_value_file_path.write_bytes(data)
        cache_key_file_path.write_text(key)
        return cache_id

    def try_get_value_text(self, key: str) -> str:
        '''Returns the value stored under ``key`` decoded as UTF-8, or None if there is no value.'''
        result = self.try_get_value_bytes(key)
        if result is None:
            return None
        return result.decode('utf-8')

    def try_get_value_bytes(self, key: str) -> bytes:
        '''Returns the value bytes stored under ``key``, or None if there is no value file.'''
        cache_id = self.hash_func(key)
        cache_value_file_path = self.cache_dir / (cache_id + KeyValueStore.VALUE_FILE_SUFFIX)
        if cache_value_file_path.exists():
            return cache_value_file_path.read_bytes()
        return None

    def exists(self, key: str) -> bool:
        '''Returns True if a key file exists for ``key``.'''
        cache_id = self.hash_func(key)
        cache_key_file_path = self.cache_dir / (cache_id + KeyValueStore.KEY_FILE_SUFFIX)
        return cache_key_file_path.exists()

    def keys(self):
        '''Yields the key text of every key file found in the cache directory.'''
        for cache_key_file_path in self.cache_dir.glob('*' + KeyValueStore.KEY_FILE_SUFFIX):
            yield Path(cache_key_file_path).read_text()

0 comments on commit 1403b9b

Please sign in to comment.