Create attribute based NumberPools when the schema is updated #6618

Merged 1 commit on Jun 9, 2025
62 changes: 61 additions & 1 deletion backend/infrahub/pools/tasks.py
@@ -4,8 +4,9 @@
from prefect.logging import get_run_logger

from infrahub.context import InfrahubContext # noqa: TC001 needed for prefect flow
from infrahub.core.constants import NumberPoolType
from infrahub.core.constants import InfrahubKind, NumberPoolType
from infrahub.core.manager import NodeManager
from infrahub.core.node import Node
from infrahub.core.protocols import CoreNumberPool
from infrahub.core.registry import registry
from infrahub.core.schema.attribute_parameters import NumberPoolParameters
@@ -54,3 +55,62 @@ async def validate_schema_number_pools(
elif not defined_on_branches:
log.info(f"Deleting number pool (id={schema_number_pool.id}) as it is no longer defined in the schema")
await schema_number_pool.delete(db=service.database)

existing_pool_ids = [pool.id for pool in schema_number_pools]
Contributor:
I don't know how big a change it is, but I think it would be better to let the id be set when we create the number pool, as is done for a regular node, instead of setting it during validate_attribute_parameters. Is there still value in having it set earlier?

Contributor:
That said, maybe I'm missing something, because I'm not sure where the number pool is currently created (regardless of this PR). Since it seems we now want to create schema number pools during schema validation, should we remove the code responsible for creating them in the first place?

Contributor Author:
It's done like this with the ID because we need the ID to be in the schema. Timing-wise, I think it's better to keep the other creation path as well, since we can't be sure of the order of operations here. I plan to open a follow-up PR later with a lock around the two of them, though.
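For context, a rough sketch of what serializing the two creation paths could look like. This is only an illustration, not code from this PR: the `ensure_number_pool` helper and the process-local `asyncio.Lock` are hypothetical (a distributed lock would be needed if the schema-load path and this flow can run in different workers), and the `NodeNotFoundError` import path is assumed.

```python
import asyncio

from infrahub.core.manager import NodeManager
from infrahub.core.protocols import CoreNumberPool
from infrahub.exceptions import NodeNotFoundError  # import path assumed

# Hypothetical process-local lock shared by both creation paths.
_schema_number_pool_lock = asyncio.Lock()


async def ensure_number_pool(
    service, number_pool_id: str, pool_node: str, pool_attribute: str, start_range: int, end_range: int
) -> None:
    """Create the schema NumberPool only if it does not already exist (sketch)."""
    async with _schema_number_pool_lock:
        try:
            # find_object raises NodeNotFoundError when the pool does not exist yet
            await NodeManager.find_object(db=service.database, kind=CoreNumberPool, id=number_pool_id)
            return
        except NodeNotFoundError:
            pass
        await _create_number_pool(
            service=service,
            number_pool_id=number_pool_id,
            pool_node=pool_node,
            pool_attribute=pool_attribute,
            start_range=start_range,
            end_range=end_range,
        )
```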

for registry_branch in registry.schema.get_branches():
schema_branch = service.database.schema.get_schema_branch(name=registry_branch)

for generic_name in schema_branch.generic_names:
generic_node = schema_branch.get_generic(name=generic_name, duplicate=False)
for attribute_name in generic_node.attribute_names:
attribute = generic_node.get_attribute(name=attribute_name)
if isinstance(attribute.parameters, NumberPoolParameters) and attribute.parameters.number_pool_id:
if attribute.parameters.number_pool_id not in existing_pool_ids:
await _create_number_pool(
service=service,
number_pool_id=attribute.parameters.number_pool_id,
pool_node=generic_node.kind,
pool_attribute=attribute_name,
start_range=attribute.parameters.start_range,
end_range=attribute.parameters.end_range,
)
existing_pool_ids.append(attribute.parameters.number_pool_id)

for node_name in schema_branch.node_names:
node = schema_branch.get_node(name=node_name, duplicate=False)
for attribute_name in node.attribute_names:
attribute = node.get_attribute(name=attribute_name)
if isinstance(attribute.parameters, NumberPoolParameters) and attribute.parameters.number_pool_id:
if attribute.parameters.number_pool_id not in existing_pool_ids:
await _create_number_pool(
service=service,
number_pool_id=attribute.parameters.number_pool_id,
pool_node=node.kind,
pool_attribute=attribute_name,
start_range=attribute.parameters.start_range,
end_range=attribute.parameters.end_range,
)
existing_pool_ids.append(attribute.parameters.number_pool_id)
Contributor:
Can we factor the two blocks above into a single one, even if it means calling SchemaManager.get?

Contributor Author:
I think for now the above is actually a bit cleaner, simply because of the kind name: in the case of a node schema we'd need to find the name of the generic that defined the attribute. I'll take another look at this later.
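For reference, a rough sketch of what that factored version might look like. It assumes `schema_branch.get()` (in the spirit of the `SchemaManager.get` suggestion) can resolve both generics and nodes, and it sidesteps the caveat above: `schema_node.kind` would still not point at the generic that defined the attribute, which is the main reason the PR keeps two loops.

```python
async def _create_missing_number_pools(service, schema_branch, existing_pool_ids) -> None:
    """Single pass over generics and nodes, replacing the two loops above (sketch)."""
    for name in [*schema_branch.generic_names, *schema_branch.node_names]:
        schema_node = schema_branch.get(name=name, duplicate=False)  # assumed to resolve both generics and nodes
        for attribute_name in schema_node.attribute_names:
            attribute = schema_node.get_attribute(name=attribute_name)
            if not isinstance(attribute.parameters, NumberPoolParameters):
                continue
            number_pool_id = attribute.parameters.number_pool_id
            if not number_pool_id or number_pool_id in existing_pool_ids:
                continue
            await _create_number_pool(
                service=service,
                number_pool_id=number_pool_id,
                pool_node=schema_node.kind,
                pool_attribute=attribute_name,
                start_range=attribute.parameters.start_range,
                end_range=attribute.parameters.end_range,
            )
            existing_pool_ids.append(number_pool_id)
```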



async def _create_number_pool(
service: InfrahubServices,
number_pool_id: str,
pool_node: str,
pool_attribute: str,
start_range: int,
end_range: int,
) -> None:
async with service.database.start_session() as dbs:
number_pool = await Node.init(db=dbs, schema=InfrahubKind.NUMBERPOOL, branch=registry.default_branch)
await number_pool.new(
db=dbs,
id=number_pool_id,
name=f"{pool_node}.{pool_attribute} [{number_pool_id}]",
node=pool_node,
node_attribute=pool_attribute,
start_range=start_range,
end_range=end_range,
pool_type=NumberPoolType.SCHEMA.value,
)
await number_pool.save(db=dbs)
71 changes: 64 additions & 7 deletions backend/tests/functional/pools/test_numberpool_lifecycle.py
@@ -16,6 +16,7 @@
from infrahub.pools.tasks import validate_schema_number_pools
from infrahub.services import InfrahubServices
from infrahub.services.adapters.cache.redis import RedisCache
from tests.helpers.schema.snow import SNOW_INCIDENT, SNOW_REQUEST, SNOW_TASK
from tests.helpers.test_app import TestInfrahubApp

if TYPE_CHECKING:
@@ -43,7 +44,7 @@
}


class TestMutationGenerator(TestInfrahubApp):
class TestAttributeNumberPoolLifecycle(TestInfrahubApp):
@pytest.fixture(scope="class")
async def initial_dataset(
self,
@@ -56,18 +57,24 @@ async def initial_dataset(
) -> None:
bus_simulator.service._cache = RedisCache()

schema = {"version": "1.0", "nodes": [node_schema_definition]}
schema = {
"version": "1.0",
"generics": [SNOW_TASK.to_dict()],
"nodes": [node_schema_definition, SNOW_INCIDENT.to_dict(), SNOW_REQUEST.to_dict()],
}
schema_load_response = await client.schema.load(schemas=[schema], wait_until_converged=True)
assert not schema_load_response.errors

async def test_numberpool_assignment(
async def test_numberpool_assignment_direct_node(
self, db: InfrahubDatabase, initial_dataset: None, client: InfrahubClient, default_branch
) -> None:
assert True
service = await InfrahubServices.new(database=db)
context = InfrahubContext.init(
branch=default_branch,
account=AccountSession(auth_type=AuthType.NONE, authenticated=False, account_id=""),
)

incident_1 = await Node.init(db=db, schema="TestNumberAttribute")
await incident_1.new(db=db, name="The first thing")
await incident_1.save(db=db)
await validate_schema_number_pools(branch_name=registry.default_branch, context=context, service=service)

test_schema = registry.schema.get_node_schema(name="TestNumberAttribute")
test_attribute = test_schema.get_attribute(name="assigned_number")
@@ -81,6 +88,10 @@ async def test_numberpool_assignment(
)
assert number_pool_pre.start_range.value == 10

incident_1 = await Node.init(db=db, schema="TestNumberAttribute")
await incident_1.new(db=db, name="The first thing")
await incident_1.save(db=db)

initial_branches = get_branches_with_schema_number_pool(
kind="TestNumberAttribute", attribute_name="assigned_number"
)
@@ -94,6 +105,18 @@ async def test_numberpool_assignment(
after_purge = get_branches_with_schema_number_pool(kind="TestNumberAttribute", attribute_name="assigned_number")
assert after_purge == []

await validate_schema_number_pools(branch_name=registry.default_branch, context=context, service=service)

with pytest.raises(NodeNotFoundError):
await NodeManager.find_object(
db=db,
kind=CoreNumberPool,
id=number_pool_id,
)

async def test_numberpool_assignment_from_generic(
self, db: InfrahubDatabase, initial_dataset: None, client: InfrahubClient, default_branch
) -> None:
service = await InfrahubServices.new(database=db)
context = InfrahubContext.init(
branch=default_branch,
@@ -102,6 +125,40 @@ async def test_numberpool_assignment(

await validate_schema_number_pools(branch_name=registry.default_branch, context=context, service=service)

test_schema = registry.schema.get_node_schema(name="SnowIncident")
test_attribute = test_schema.get_attribute(name="number")
assert isinstance(test_attribute.parameters, NumberPoolParameters)
number_pool_id = test_attribute.parameters.number_pool_id

number_pool_pre = await NodeManager.find_object(
db=db,
kind=CoreNumberPool,
id=number_pool_id,
)
assert number_pool_pre.start_range.value == 1

incident_1 = await Node.init(db=db, schema="SnowIncident")
await incident_1.new(db=db, title="The very first incident")
await incident_1.save(db=db)

initial_branches = get_branches_with_schema_number_pool(kind="SnowTask", attribute_name="number")

assert initial_branches == ["main"]
snow_task = SNOW_TASK.to_dict()
snow_task["state"] = "absent"
snow_request = SNOW_REQUEST.to_dict()
snow_request["state"] = "absent"
snow_incident = SNOW_INCIDENT.to_dict()
snow_incident["state"] = "absent"
schema = {"version": "1.0", "generics": [snow_task], "nodes": [snow_request, snow_incident]}
schema_load_response = await client.schema.load(schemas=[schema], wait_until_converged=True)
assert not schema_load_response.errors

after_purge = get_branches_with_schema_number_pool(kind="SnowTask", attribute_name="number")
assert after_purge == []

await validate_schema_number_pools(branch_name=registry.default_branch, context=context, service=service)

with pytest.raises(NodeNotFoundError):
await NodeManager.find_object(
db=db,