Skip to content

Commit eb242e8

Browse files
committed
Create attribute based NumberPools when the schema is updated
1 parent 13bb9f4 commit eb242e8

File tree

2 files changed

+125
-8
lines changed

2 files changed

+125
-8
lines changed

backend/infrahub/pools/tasks.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from prefect.logging import get_run_logger
55

66
from infrahub.context import InfrahubContext # noqa: TC001 needed for prefect flow
7-
from infrahub.core.constants import NumberPoolType
7+
from infrahub.core.constants import InfrahubKind, NumberPoolType
88
from infrahub.core.manager import NodeManager
9+
from infrahub.core.node import Node
910
from infrahub.core.protocols import CoreNumberPool
1011
from infrahub.core.registry import registry
1112
from infrahub.core.schema.attribute_parameters import NumberPoolParameters
@@ -54,3 +55,62 @@ async def validate_schema_number_pools(
5455
elif not defined_on_branches:
5556
log.info(f"Deleting number pool (id={schema_number_pool.id}) as it is no longer defined in the schema")
5657
await schema_number_pool.delete(db=service.database)
58+
59+
existing_pool_ids = [pool.id for pool in schema_number_pools]
60+
for registry_branch in registry.schema.get_branches():
61+
schema_branch = service.database.schema.get_schema_branch(name=registry_branch)
62+
63+
for generic_name in schema_branch.generic_names:
64+
generic_node = schema_branch.get_generic(name=generic_name, duplicate=False)
65+
for attribute_name in generic_node.attribute_names:
66+
attribute = generic_node.get_attribute(name=attribute_name)
67+
if isinstance(attribute.parameters, NumberPoolParameters) and attribute.parameters.number_pool_id:
68+
if attribute.parameters.number_pool_id not in existing_pool_ids:
69+
await _create_number_pool(
70+
service=service,
71+
number_pool_id=attribute.parameters.number_pool_id,
72+
pool_node=generic_node.kind,
73+
pool_attribute=attribute_name,
74+
start_range=attribute.parameters.start_range,
75+
end_range=attribute.parameters.end_range,
76+
)
77+
existing_pool_ids.append(attribute.parameters.number_pool_id)
78+
79+
for node_name in schema_branch.node_names:
80+
node = schema_branch.get_node(name=node_name, duplicate=False)
81+
for attribute_name in node.attribute_names:
82+
attribute = node.get_attribute(name=attribute_name)
83+
if isinstance(attribute.parameters, NumberPoolParameters) and attribute.parameters.number_pool_id:
84+
if attribute.parameters.number_pool_id not in existing_pool_ids:
85+
await _create_number_pool(
86+
service=service,
87+
number_pool_id=attribute.parameters.number_pool_id,
88+
pool_node=node.kind,
89+
pool_attribute=attribute_name,
90+
start_range=attribute.parameters.start_range,
91+
end_range=attribute.parameters.end_range,
92+
)
93+
existing_pool_ids.append(attribute.parameters.number_pool_id)
94+
95+
96+
async def _create_number_pool(
97+
service: InfrahubServices,
98+
number_pool_id: str,
99+
pool_node: str,
100+
pool_attribute: str,
101+
start_range: int,
102+
end_range: int,
103+
) -> None:
104+
async with service.database.start_session() as dbs:
105+
number_pool = await Node.init(db=dbs, schema=InfrahubKind.NUMBERPOOL, branch=registry.default_branch)
106+
await number_pool.new(
107+
db=dbs,
108+
id=number_pool_id,
109+
name=f"{pool_node}.{pool_attribute} [{number_pool_id}]",
110+
node=pool_node,
111+
node_attribute=pool_attribute,
112+
start_range=start_range,
113+
end_range=end_range,
114+
pool_type=NumberPoolType.SCHEMA.value,
115+
)
116+
await number_pool.save(db=dbs)

backend/tests/functional/pools/test_numberpool_lifecycle.py

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from infrahub.pools.tasks import validate_schema_number_pools
1717
from infrahub.services import InfrahubServices
1818
from infrahub.services.adapters.cache.redis import RedisCache
19+
from tests.helpers.schema.snow import SNOW_INCIDENT, SNOW_REQUEST, SNOW_TASK
1920
from tests.helpers.test_app import TestInfrahubApp
2021

2122
if TYPE_CHECKING:
@@ -43,7 +44,7 @@
4344
}
4445

4546

46-
class TestMutationGenerator(TestInfrahubApp):
47+
class TestAttributeNumberPoolLifecycle(TestInfrahubApp):
4748
@pytest.fixture(scope="class")
4849
async def initial_dataset(
4950
self,
@@ -56,18 +57,24 @@ async def initial_dataset(
5657
) -> None:
5758
bus_simulator.service._cache = RedisCache()
5859

59-
schema = {"version": "1.0", "nodes": [node_schema_definition]}
60+
schema = {
61+
"version": "1.0",
62+
"generics": [SNOW_TASK.to_dict()],
63+
"nodes": [node_schema_definition, SNOW_INCIDENT.to_dict(), SNOW_REQUEST.to_dict()],
64+
}
6065
schema_load_response = await client.schema.load(schemas=[schema], wait_until_converged=True)
6166
assert not schema_load_response.errors
6267

63-
async def test_numberpool_assignment(
68+
async def test_numberpool_assignment_direct_node(
6469
self, db: InfrahubDatabase, initial_dataset: None, client: InfrahubClient, default_branch
6570
) -> None:
66-
assert True
71+
service = await InfrahubServices.new(database=db)
72+
context = InfrahubContext.init(
73+
branch=default_branch,
74+
account=AccountSession(auth_type=AuthType.NONE, authenticated=False, account_id=""),
75+
)
6776

68-
incident_1 = await Node.init(db=db, schema="TestNumberAttribute")
69-
await incident_1.new(db=db, name="The first thing")
70-
await incident_1.save(db=db)
77+
await validate_schema_number_pools(branch_name=registry.default_branch, context=context, service=service)
7178

7279
test_schema = registry.schema.get_node_schema(name="TestNumberAttribute")
7380
test_attribute = test_schema.get_attribute(name="assigned_number")
@@ -81,6 +88,10 @@ async def test_numberpool_assignment(
8188
)
8289
assert number_pool_pre.start_range.value == 10
8390

91+
incident_1 = await Node.init(db=db, schema="TestNumberAttribute")
92+
await incident_1.new(db=db, name="The first thing")
93+
await incident_1.save(db=db)
94+
8495
initial_branches = get_branches_with_schema_number_pool(
8596
kind="TestNumberAttribute", attribute_name="assigned_number"
8697
)
@@ -94,6 +105,18 @@ async def test_numberpool_assignment(
94105
after_purge = get_branches_with_schema_number_pool(kind="TestNumberAttribute", attribute_name="assigned_number")
95106
assert after_purge == []
96107

108+
await validate_schema_number_pools(branch_name=registry.default_branch, context=context, service=service)
109+
110+
with pytest.raises(NodeNotFoundError):
111+
await NodeManager.find_object(
112+
db=db,
113+
kind=CoreNumberPool,
114+
id=number_pool_id,
115+
)
116+
117+
async def test_numberpool_assignment_from_generic(
118+
self, db: InfrahubDatabase, initial_dataset: None, client: InfrahubClient, default_branch
119+
) -> None:
97120
service = await InfrahubServices.new(database=db)
98121
context = InfrahubContext.init(
99122
branch=default_branch,
@@ -102,6 +125,40 @@ async def test_numberpool_assignment(
102125

103126
await validate_schema_number_pools(branch_name=registry.default_branch, context=context, service=service)
104127

128+
test_schema = registry.schema.get_node_schema(name="SnowIncident")
129+
test_attribute = test_schema.get_attribute(name="number")
130+
assert isinstance(test_attribute.parameters, NumberPoolParameters)
131+
number_pool_id = test_attribute.parameters.number_pool_id
132+
133+
number_pool_pre = await NodeManager.find_object(
134+
db=db,
135+
kind=CoreNumberPool,
136+
id=number_pool_id,
137+
)
138+
assert number_pool_pre.start_range.value == 1
139+
140+
incident_1 = await Node.init(db=db, schema="SnowIncident")
141+
await incident_1.new(db=db, title="The very first incident")
142+
await incident_1.save(db=db)
143+
144+
initial_branches = get_branches_with_schema_number_pool(kind="SnowTask", attribute_name="number")
145+
146+
assert initial_branches == ["main"]
147+
snow_task = SNOW_TASK.to_dict()
148+
snow_task["state"] = "absent"
149+
snow_request = SNOW_REQUEST.to_dict()
150+
snow_request["state"] = "absent"
151+
snow_incident = SNOW_INCIDENT.to_dict()
152+
snow_incident["state"] = "absent"
153+
schema = {"version": "1.0", "generics": [snow_task], "nodes": [snow_request, snow_incident]}
154+
schema_load_response = await client.schema.load(schemas=[schema], wait_until_converged=True)
155+
assert not schema_load_response.errors
156+
157+
after_purge = get_branches_with_schema_number_pool(kind="SnowTask", attribute_name="number")
158+
assert after_purge == []
159+
160+
await validate_schema_number_pools(branch_name=registry.default_branch, context=context, service=service)
161+
105162
with pytest.raises(NodeNotFoundError):
106163
await NodeManager.find_object(
107164
db=db,

0 commit comments

Comments
 (0)