-
Notifications
You must be signed in to change notification settings - Fork 6
Add method to populate the cache of the schema manually #360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
By default, schema.fetch will now populate the cache (this behavior can be changed with `populate_cache`) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Add method `client.schema.set_cache()` to populate the cache manually (primarily for unit testing) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
The 'timeout' parameter while creating a node or fetching the schema has been deprecated. the default_timeout will be used instead. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
|
||
import asyncio | ||
import json | ||
import warnings | ||
from collections.abc import MutableMapping | ||
from enum import Enum | ||
from time import sleep | ||
|
@@ -90,6 +91,13 @@ | |
|
||
|
||
class InfrahubSchemaBase: | ||
client: InfrahubClient | InfrahubClientSync | ||
cache: dict[str, BranchSchema] | ||
|
||
def __init__(self, client: InfrahubClient | InfrahubClientSync): | ||
self.client = client | ||
self.cache = {} | ||
|
||
def validate(self, data: dict[str, Any]) -> None: | ||
SchemaRoot(**data) | ||
|
||
|
@@ -102,6 +110,23 @@ | |
message=f"{key} is not a valid value for {identifier}", | ||
) | ||
|
||
def set_cache(self, schema: dict[str, Any] | SchemaRootAPI | BranchSchema, branch: str | None = None) -> None: | ||
""" | ||
Set the cache manually (primarily for unit testing) | ||
|
||
Args: | ||
schema: The schema to set the cache as provided by the /api/schema endpoint either in dict or SchemaRootAPI format | ||
branch: The name of the branch to set the cache for. | ||
""" | ||
branch = branch or self.client.default_branch | ||
|
||
if isinstance(schema, SchemaRootAPI): | ||
schema = BranchSchema.from_schema_root_api(data=schema) | ||
elif isinstance(schema, dict): | ||
schema = BranchSchema.from_api_response(data=schema) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In a way I'd prefer if .set_cache took a BranchSchema object directly. That way there's no confusion regarding what the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
self.cache[branch] = schema | ||
|
||
def generate_payload_create( | ||
self, | ||
schema: MainSchemaTypesAPI, | ||
|
@@ -187,11 +212,18 @@ | |
|
||
return data | ||
|
||
@staticmethod | ||
def _deprecated_schema_timeout() -> None: | ||
warnings.warn( | ||
"The 'timeout' parameter is deprecated while fetching the schema and will be removed version 2.0.0 of the Infrahub Python SDK. " | ||
"Use client.default_timeout instead.", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will you also create an issue in the 2.0 milestone to remove this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will do |
||
DeprecationWarning, | ||
stacklevel=2, | ||
) | ||
|
||
|
||
class InfrahubSchema(InfrahubSchemaBase): | ||
def __init__(self, client: InfrahubClient): | ||
self.client = client | ||
self.cache: dict[str, BranchSchema] = {} | ||
client: InfrahubClient | ||
|
||
async def get( | ||
self, | ||
|
@@ -204,16 +236,19 @@ | |
|
||
kind_str = self._get_schema_name(schema=kind) | ||
|
||
if timeout: | ||
self._deprecated_schema_timeout() | ||
|
||
if refresh: | ||
self.cache[branch] = await self._fetch(branch=branch, timeout=timeout) | ||
self.cache[branch] = await self._fetch(branch=branch) | ||
|
||
if branch in self.cache and kind_str in self.cache[branch].nodes: | ||
return self.cache[branch].nodes[kind_str] | ||
|
||
# Fetching the latest schema from the server if we didn't fetch it earlier | ||
# because we coulnd't find the object on the local cache | ||
if not refresh: | ||
self.cache[branch] = await self._fetch(branch=branch, timeout=timeout) | ||
self.cache[branch] = await self._fetch(branch=branch) | ||
|
||
if branch in self.cache and kind_str in self.cache[branch].nodes: | ||
return self.cache[branch].nodes[kind_str] | ||
|
@@ -416,59 +451,45 @@ | |
) | ||
|
||
async def fetch( | ||
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None | ||
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None, populate_cache: bool = True | ||
) -> MutableMapping[str, MainSchemaTypesAPI]: | ||
"""Fetch the schema from the server for a given branch. | ||
|
||
Args: | ||
branch (str): Name of the branch to fetch the schema for. | ||
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds. | ||
branch: Name of the branch to fetch the schema for. | ||
timeout: Overrides default timeout used when querying the schema. deprecated. | ||
populate_cache: Whether to populate the cache with the fetched schema. Defaults to True. | ||
|
||
Returns: | ||
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind | ||
""" | ||
branch_schema = await self._fetch(branch=branch, namespaces=namespaces, timeout=timeout) | ||
|
||
if timeout: | ||
self._deprecated_schema_timeout() | ||
|
||
branch_schema = await self._fetch(branch=branch, namespaces=namespaces) | ||
|
||
if populate_cache: | ||
self.cache[branch] = branch_schema | ||
|
||
return branch_schema.nodes | ||
|
||
async def _fetch( | ||
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None | ||
) -> BranchSchema: | ||
async def _fetch(self, branch: str, namespaces: list[str] | None = None) -> BranchSchema: | ||
url_parts = [("branch", branch)] | ||
if namespaces: | ||
url_parts.extend([("namespaces", ns) for ns in namespaces]) | ||
query_params = urlencode(url_parts) | ||
url = f"{self.client.address}/api/schema?{query_params}" | ||
|
||
response = await self.client._get(url=url, timeout=timeout) | ||
response = await self.client._get(url=url) | ||
|
||
data = self._parse_schema_response(response=response, branch=branch) | ||
|
||
nodes: MutableMapping[str, MainSchemaTypesAPI] = {} | ||
for node_schema in data.get("nodes", []): | ||
node = NodeSchemaAPI(**node_schema) | ||
nodes[node.kind] = node | ||
|
||
for generic_schema in data.get("generics", []): | ||
generic = GenericSchemaAPI(**generic_schema) | ||
nodes[generic.kind] = generic | ||
|
||
for profile_schema in data.get("profiles", []): | ||
profile = ProfileSchemaAPI(**profile_schema) | ||
nodes[profile.kind] = profile | ||
|
||
for template_schema in data.get("templates", []): | ||
template = TemplateSchemaAPI(**template_schema) | ||
nodes[template.kind] = template | ||
|
||
schema_hash = data.get("main", "") | ||
|
||
return BranchSchema(hash=schema_hash, nodes=nodes) | ||
return BranchSchema.from_api_response(data=data) | ||
|
||
|
||
class InfrahubSchemaSync(InfrahubSchemaBase): | ||
def __init__(self, client: InfrahubClientSync): | ||
self.client = client | ||
self.cache: dict[str, BranchSchema] = {} | ||
client: InfrahubClientSync | ||
|
||
def all( | ||
self, | ||
|
@@ -506,10 +527,25 @@ | |
refresh: bool = False, | ||
timeout: int | None = None, | ||
) -> MainSchemaTypesAPI: | ||
""" | ||
Retrieve a specific schema object from the server. | ||
|
||
Args: | ||
kind: The kind of schema object to retrieve. | ||
branch: The branch to retrieve the schema from. | ||
refresh: Whether to refresh the schema. | ||
timeout: Overrides default timeout used when querying the GraphQL API. Specified in seconds (deprecated). | ||
|
||
Returns: | ||
MainSchemaTypes: The schema object. | ||
""" | ||
branch = branch or self.client.default_branch | ||
|
||
kind_str = self._get_schema_name(schema=kind) | ||
|
||
if timeout: | ||
self._deprecated_schema_timeout() | ||
|
||
if refresh: | ||
self.cache[branch] = self._fetch(branch=branch) | ||
|
||
|
@@ -519,7 +555,7 @@ | |
# Fetching the latest schema from the server if we didn't fetch it earlier | ||
# because we coulnd't find the object on the local cache | ||
if not refresh: | ||
self.cache[branch] = self._fetch(branch=branch, timeout=timeout) | ||
self.cache[branch] = self._fetch(branch=branch) | ||
|
||
if branch in self.cache and kind_str in self.cache[branch].nodes: | ||
return self.cache[branch].nodes[kind_str] | ||
|
@@ -639,49 +675,39 @@ | |
) | ||
|
||
def fetch( | ||
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None | ||
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None, populate_cache: bool = True | ||
) -> MutableMapping[str, MainSchemaTypesAPI]: | ||
"""Fetch the schema from the server for a given branch. | ||
|
||
Args: | ||
branch (str): Name of the branch to fetch the schema for. | ||
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds. | ||
branch: Name of the branch to fetch the schema for. | ||
timeout: Overrides default timeout used when querying the GraphQL API. Specified in seconds (deprecated). | ||
populate_cache: Whether to populate the cache with the fetched schema. Defaults to True. | ||
|
||
Returns: | ||
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind | ||
""" | ||
branch_schema = self._fetch(branch=branch, namespaces=namespaces, timeout=timeout) | ||
if timeout: | ||
self._deprecated_schema_timeout() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should also deprecate timeout in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's already covered because if someone use |
||
|
||
branch_schema = self._fetch(branch=branch, namespaces=namespaces) | ||
|
||
if populate_cache: | ||
self.cache[branch] = branch_schema | ||
|
||
return branch_schema.nodes | ||
|
||
def _fetch(self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None) -> BranchSchema: | ||
def _fetch(self, branch: str, namespaces: list[str] | None = None) -> BranchSchema: | ||
url_parts = [("branch", branch)] | ||
if namespaces: | ||
url_parts.extend([("namespaces", ns) for ns in namespaces]) | ||
query_params = urlencode(url_parts) | ||
url = f"{self.client.address}/api/schema?{query_params}" | ||
response = self.client._get(url=url, timeout=timeout) | ||
data = self._parse_schema_response(response=response, branch=branch) | ||
response = self.client._get(url=url) | ||
|
||
nodes: MutableMapping[str, MainSchemaTypesAPI] = {} | ||
for node_schema in data.get("nodes", []): | ||
node = NodeSchemaAPI(**node_schema) | ||
nodes[node.kind] = node | ||
|
||
for generic_schema in data.get("generics", []): | ||
generic = GenericSchemaAPI(**generic_schema) | ||
nodes[generic.kind] = generic | ||
|
||
for profile_schema in data.get("profiles", []): | ||
profile = ProfileSchemaAPI(**profile_schema) | ||
nodes[profile.kind] = profile | ||
|
||
for template_schema in data.get("templates", []): | ||
template = TemplateSchemaAPI(**template_schema) | ||
nodes[template.kind] = template | ||
|
||
schema_hash = data.get("main", "") | ||
data = self._parse_schema_response(response=response, branch=branch) | ||
|
||
return BranchSchema(hash=schema_hash, nodes=nodes) | ||
return BranchSchema.from_api_response(data=data) | ||
|
||
def load( | ||
self, schemas: list[dict], branch: str | None = None, wait_until_converged: bool = False | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
set_cache
should be a private method as it's meant for internal use.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's not meant for internal use, it's meant for user to use during unit tests