Skip to content

Add method to populate the cache of the schema manually #360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/+schema-fetch.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
By default, schema.fetch will now populate the cache (this behavior can be changed with `populate_cache`)
1 change: 1 addition & 0 deletions changelog/+schema-set-cache.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add method `client.schema.set_cache()` to populate the cache manually (primarily for unit testing)
1 change: 1 addition & 0 deletions changelog/+schema-timeout.deprecated.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The 'timeout' parameter while creating a node or fetching the schema has been deprecated. the default_timeout will be used instead.
2 changes: 1 addition & 1 deletion infrahub_sdk/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def client(self, value: InfrahubClient) -> None:
async def init(cls, client: InfrahubClient | None = None, *args: Any, **kwargs: Any) -> InfrahubCheck:
"""Async init method, If an existing InfrahubClient client hasn't been provided, one will be created automatically."""
warnings.warn(
"InfrahubCheck.init has been deprecated and will be removed in the version in Infrahub SDK 2.0.0",
"InfrahubCheck.init has been deprecated and will be removed in version 2.0.0 of the Infrahub Python SDK",
DeprecationWarning,
stacklevel=1,
)
Expand Down
152 changes: 89 additions & 63 deletions infrahub_sdk/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import json
import warnings
from collections.abc import MutableMapping
from enum import Enum
from time import sleep
Expand Down Expand Up @@ -90,6 +91,13 @@


class InfrahubSchemaBase:
client: InfrahubClient | InfrahubClientSync
cache: dict[str, BranchSchema]

def __init__(self, client: InfrahubClient | InfrahubClientSync):
self.client = client
self.cache = {}

def validate(self, data: dict[str, Any]) -> None:
SchemaRoot(**data)

Expand All @@ -102,6 +110,23 @@
message=f"{key} is not a valid value for {identifier}",
)

def set_cache(self, schema: dict[str, Any] | SchemaRootAPI | BranchSchema, branch: str | None = None) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think set_cache should be a private method as it's meant for internal use.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not meant for internal use, it's meant for user to use during unit tests

"""
Set the cache manually (primarily for unit testing)

Args:
schema: The schema to set the cache as provided by the /api/schema endpoint either in dict or SchemaRootAPI format
branch: The name of the branch to set the cache for.
"""
branch = branch or self.client.default_branch

if isinstance(schema, SchemaRootAPI):
schema = BranchSchema.from_schema_root_api(data=schema)

Check warning on line 124 in infrahub_sdk/schema/__init__.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/schema/__init__.py#L124

Added line #L124 was not covered by tests
elif isinstance(schema, dict):
schema = BranchSchema.from_api_response(data=schema)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a way I'd prefer if .set_cache took a BranchSchema object directly. That way there's no confusion regarding what the dict[str, Any] should look like. It would mostly be relevant if there are any errors in the schema, then those errors would be raised in the correct location. On the other hand it might be slightly easier for some users if they don't have to convert it into a dict.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BranchSchema is primarily an internal object so in most cases I would expect users to use the dict most of the time.
To clarify, I've added support for SchemaRootAPI as well which is the format returned by the /api/schema endpoint


self.cache[branch] = schema

def generate_payload_create(
self,
schema: MainSchemaTypesAPI,
Expand Down Expand Up @@ -187,11 +212,18 @@

return data

@staticmethod
def _deprecated_schema_timeout() -> None:
warnings.warn(

Check warning on line 217 in infrahub_sdk/schema/__init__.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/schema/__init__.py#L217

Added line #L217 was not covered by tests
"The 'timeout' parameter is deprecated while fetching the schema and will be removed version 2.0.0 of the Infrahub Python SDK. "
"Use client.default_timeout instead.",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will you also create an issue in the 2.0 milestone to remove this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do

DeprecationWarning,
stacklevel=2,
)


class InfrahubSchema(InfrahubSchemaBase):
def __init__(self, client: InfrahubClient):
self.client = client
self.cache: dict[str, BranchSchema] = {}
client: InfrahubClient

async def get(
self,
Expand All @@ -204,16 +236,19 @@

kind_str = self._get_schema_name(schema=kind)

if timeout:
self._deprecated_schema_timeout()

Check warning on line 240 in infrahub_sdk/schema/__init__.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/schema/__init__.py#L240

Added line #L240 was not covered by tests

if refresh:
self.cache[branch] = await self._fetch(branch=branch, timeout=timeout)
self.cache[branch] = await self._fetch(branch=branch)

Check warning on line 243 in infrahub_sdk/schema/__init__.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/schema/__init__.py#L243

Added line #L243 was not covered by tests

if branch in self.cache and kind_str in self.cache[branch].nodes:
return self.cache[branch].nodes[kind_str]

# Fetching the latest schema from the server if we didn't fetch it earlier
# because we coulnd't find the object on the local cache
if not refresh:
self.cache[branch] = await self._fetch(branch=branch, timeout=timeout)
self.cache[branch] = await self._fetch(branch=branch)

if branch in self.cache and kind_str in self.cache[branch].nodes:
return self.cache[branch].nodes[kind_str]
Expand Down Expand Up @@ -416,59 +451,45 @@
)

async def fetch(
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None, populate_cache: bool = True
) -> MutableMapping[str, MainSchemaTypesAPI]:
"""Fetch the schema from the server for a given branch.

Args:
branch (str): Name of the branch to fetch the schema for.
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
branch: Name of the branch to fetch the schema for.
timeout: Overrides default timeout used when querying the schema. deprecated.
populate_cache: Whether to populate the cache with the fetched schema. Defaults to True.

Returns:
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
"""
branch_schema = await self._fetch(branch=branch, namespaces=namespaces, timeout=timeout)

if timeout:
self._deprecated_schema_timeout()

Check warning on line 468 in infrahub_sdk/schema/__init__.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/schema/__init__.py#L468

Added line #L468 was not covered by tests

branch_schema = await self._fetch(branch=branch, namespaces=namespaces)

if populate_cache:
self.cache[branch] = branch_schema

return branch_schema.nodes

async def _fetch(
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None
) -> BranchSchema:
async def _fetch(self, branch: str, namespaces: list[str] | None = None) -> BranchSchema:
url_parts = [("branch", branch)]
if namespaces:
url_parts.extend([("namespaces", ns) for ns in namespaces])
query_params = urlencode(url_parts)
url = f"{self.client.address}/api/schema?{query_params}"

response = await self.client._get(url=url, timeout=timeout)
response = await self.client._get(url=url)

data = self._parse_schema_response(response=response, branch=branch)

nodes: MutableMapping[str, MainSchemaTypesAPI] = {}
for node_schema in data.get("nodes", []):
node = NodeSchemaAPI(**node_schema)
nodes[node.kind] = node

for generic_schema in data.get("generics", []):
generic = GenericSchemaAPI(**generic_schema)
nodes[generic.kind] = generic

for profile_schema in data.get("profiles", []):
profile = ProfileSchemaAPI(**profile_schema)
nodes[profile.kind] = profile

for template_schema in data.get("templates", []):
template = TemplateSchemaAPI(**template_schema)
nodes[template.kind] = template

schema_hash = data.get("main", "")

return BranchSchema(hash=schema_hash, nodes=nodes)
return BranchSchema.from_api_response(data=data)


class InfrahubSchemaSync(InfrahubSchemaBase):
def __init__(self, client: InfrahubClientSync):
self.client = client
self.cache: dict[str, BranchSchema] = {}
client: InfrahubClientSync

def all(
self,
Expand Down Expand Up @@ -506,10 +527,25 @@
refresh: bool = False,
timeout: int | None = None,
) -> MainSchemaTypesAPI:
"""
Retrieve a specific schema object from the server.

Args:
kind: The kind of schema object to retrieve.
branch: The branch to retrieve the schema from.
refresh: Whether to refresh the schema.
timeout: Overrides default timeout used when querying the GraphQL API. Specified in seconds (deprecated).

Returns:
MainSchemaTypes: The schema object.
"""
branch = branch or self.client.default_branch

kind_str = self._get_schema_name(schema=kind)

if timeout:
self._deprecated_schema_timeout()

Check warning on line 547 in infrahub_sdk/schema/__init__.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/schema/__init__.py#L547

Added line #L547 was not covered by tests

if refresh:
self.cache[branch] = self._fetch(branch=branch)

Expand All @@ -519,7 +555,7 @@
# Fetching the latest schema from the server if we didn't fetch it earlier
# because we coulnd't find the object on the local cache
if not refresh:
self.cache[branch] = self._fetch(branch=branch, timeout=timeout)
self.cache[branch] = self._fetch(branch=branch)

if branch in self.cache and kind_str in self.cache[branch].nodes:
return self.cache[branch].nodes[kind_str]
Expand Down Expand Up @@ -639,49 +675,39 @@
)

def fetch(
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None, populate_cache: bool = True
) -> MutableMapping[str, MainSchemaTypesAPI]:
"""Fetch the schema from the server for a given branch.

Args:
branch (str): Name of the branch to fetch the schema for.
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
branch: Name of the branch to fetch the schema for.
timeout: Overrides default timeout used when querying the GraphQL API. Specified in seconds (deprecated).
populate_cache: Whether to populate the cache with the fetched schema. Defaults to True.

Returns:
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
"""
branch_schema = self._fetch(branch=branch, namespaces=namespaces, timeout=timeout)
if timeout:
self._deprecated_schema_timeout()

Check warning on line 691 in infrahub_sdk/schema/__init__.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/schema/__init__.py#L691

Added line #L691 was not covered by tests
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also deprecate timeout in the client.create() methods it's only used to pass the argument to the schema function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's already covered because if someone use timeout while creating a node, it will generate a deprecation warning.
Having said that I've updated the changelog to reflect that properly


branch_schema = self._fetch(branch=branch, namespaces=namespaces)

if populate_cache:
self.cache[branch] = branch_schema

return branch_schema.nodes

def _fetch(self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None) -> BranchSchema:
def _fetch(self, branch: str, namespaces: list[str] | None = None) -> BranchSchema:
url_parts = [("branch", branch)]
if namespaces:
url_parts.extend([("namespaces", ns) for ns in namespaces])
query_params = urlencode(url_parts)
url = f"{self.client.address}/api/schema?{query_params}"
response = self.client._get(url=url, timeout=timeout)
data = self._parse_schema_response(response=response, branch=branch)
response = self.client._get(url=url)

nodes: MutableMapping[str, MainSchemaTypesAPI] = {}
for node_schema in data.get("nodes", []):
node = NodeSchemaAPI(**node_schema)
nodes[node.kind] = node

for generic_schema in data.get("generics", []):
generic = GenericSchemaAPI(**generic_schema)
nodes[generic.kind] = generic

for profile_schema in data.get("profiles", []):
profile = ProfileSchemaAPI(**profile_schema)
nodes[profile.kind] = profile

for template_schema in data.get("templates", []):
template = TemplateSchemaAPI(**template_schema)
nodes[template.kind] = template

schema_hash = data.get("main", "")
data = self._parse_schema_response(response=response, branch=branch)

return BranchSchema(hash=schema_hash, nodes=nodes)
return BranchSchema.from_api_response(data=data)

def load(
self, schemas: list[dict], branch: str | None = None, wait_until_converged: bool = False
Expand Down
32 changes: 31 additions & 1 deletion infrahub_sdk/schema/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import TYPE_CHECKING, Any, Union

from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Self

if TYPE_CHECKING:
from ..node import InfrahubNode, InfrahubNodeSync
Expand Down Expand Up @@ -344,7 +345,7 @@ def to_schema_dict(self) -> dict[str, Any]:
class SchemaRootAPI(BaseModel):
model_config = ConfigDict(use_enum_values=True)

version: str
main: str | None = None
generics: list[GenericSchemaAPI] = Field(default_factory=list)
nodes: list[NodeSchemaAPI] = Field(default_factory=list)
profiles: list[ProfileSchemaAPI] = Field(default_factory=list)
Expand All @@ -356,3 +357,32 @@ class BranchSchema(BaseModel):
nodes: MutableMapping[str, GenericSchemaAPI | NodeSchemaAPI | ProfileSchemaAPI | TemplateSchemaAPI] = Field(
default_factory=dict
)

@classmethod
def from_api_response(cls, data: MutableMapping[str, Any]) -> Self:
"""
Convert an API response from /api/schema into a BranchSchema object.
"""
return cls.from_schema_root_api(data=SchemaRootAPI(**data))

@classmethod
def from_schema_root_api(cls, data: SchemaRootAPI) -> Self:
"""
Convert a SchemaRootAPI object to a BranchSchema object.
"""
nodes: MutableMapping[str, GenericSchemaAPI | NodeSchemaAPI | ProfileSchemaAPI | TemplateSchemaAPI] = {}
for node in data.nodes:
nodes[node.kind] = node

for generic in data.generics:
nodes[generic.kind] = generic

for profile in data.profiles:
nodes[profile.kind] = profile

for template in data.templates:
nodes[template.kind] = template

schema_hash = data.main or ""

return cls(hash=schema_hash, nodes=nodes)
Loading