Skip to content

Commit 4e9a897

Browse files
Adds derivation type unspecified and exposes derivation_type for get_entity_derivations (#123)
* Adds derivation type unspecified. * Exposes derivation_type for get_entity_derivations
1 parent 07e2a41 commit 4e9a897

File tree

4 files changed

+46
-25
lines changed

4 files changed

+46
-25
lines changed

src/entitysdk/_server_schemas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ class ContributionCreate(BaseModel):
213213
class DerivationType(StrEnum):
214214
circuit_extraction = "circuit_extraction"
215215
circuit_rewiring = "circuit_rewiring"
216+
unspecified = "unspecified"
216217

217218

218219
class DetailedFile(BaseModel):

src/entitysdk/client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
AssetLabel,
2828
ContentType,
2929
DeploymentEnvironment,
30+
DerivationType,
3031
StorageType,
3132
Token,
3233
)
@@ -169,14 +170,15 @@ def get_entity_derivations(
169170
*,
170171
entity_id: ID,
171172
entity_type: type[Entity],
173+
derivation_type: DerivationType,
172174
project_context: ProjectContext | None = None,
173175
) -> IteratorResult[Entity]:
174176
"""Get all the derivation for an entity."""
175177
return core.get_entity_derivations(
176178
api_url=self.api_url,
177179
entity_id=entity_id,
178180
entity_type=entity_type,
179-
derivation_type=None,
181+
derivation_type=derivation_type,
180182
project_context=self._required_user_context(override_context=project_context),
181183
token=self._token_manager.get_token(),
182184
)

src/entitysdk/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def get_entity_derivations(
9595
entity_id: ID,
9696
entity_type: type[Entity],
9797
project_context: ProjectContext,
98-
derivation_type: DerivationType | None,
98+
derivation_type: DerivationType,
9999
token: str,
100100
http_client: httpx.Client | None = None,
101101
) -> IteratorResult[Entity]:
@@ -106,7 +106,7 @@ def get_entity_derivations(
106106
entity_id=entity_id,
107107
)
108108

109-
params = {"derivation_type": DerivationType(derivation_type)} if derivation_type else None
109+
params = {"derivation_type": DerivationType(derivation_type)}
110110

111111
response = make_db_api_request(
112112
url=url,

tests/unit/test_client.py

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,30 +1121,48 @@ def test_client_get_entity_derivations(mock_route, client, httpx_mock, api_url,
11211121
used_id = uuid.uuid4()
11221122
generated_id = uuid.uuid4()
11231123

1124-
httpx_mock.add_response(
1125-
method="GET",
1126-
url=f"{api_url}/circuit/{entity_id}/derived-from",
1127-
match_headers=request_headers,
1128-
json={
1129-
"data": [
1130-
{
1131-
"id": str(derivation_1),
1132-
"used_id": str(used_id),
1133-
"generated_id": str(generated_id),
1134-
},
1135-
{
1136-
"id": str(derivation_2),
1137-
"used_id": str(used_id),
1138-
"generated_id": str(generated_id),
1139-
"derivation_type": DerivationType.circuit_extraction,
1140-
},
1141-
]
1142-
},
1143-
)
1124+
def add_response(derivation_type: DerivationType):
1125+
httpx_mock.add_response(
1126+
method="GET",
1127+
url=f"{api_url}/circuit/{entity_id}/derived-from?derivation_type={derivation_type}",
1128+
match_headers=request_headers,
1129+
json={
1130+
"data": [
1131+
{
1132+
"id": str(derivation_1),
1133+
"used_id": str(used_id),
1134+
"generated_id": str(generated_id),
1135+
"derivation_type": derivation_type,
1136+
},
1137+
{
1138+
"id": str(derivation_2),
1139+
"used_id": str(used_id),
1140+
"generated_id": str(generated_id),
1141+
"derivation_type": derivation_type,
1142+
},
1143+
]
1144+
},
1145+
)
11441146

1147+
add_response(DerivationType.circuit_extraction)
11451148
res = client.get_entity_derivations(
1146-
entity_id=entity_id,
1147-
entity_type=Circuit,
1149+
entity_id=entity_id, entity_type=Circuit, derivation_type=DerivationType.circuit_extraction
1150+
).all()
1151+
assert len(res) == 2
1152+
assert res[0].id == derivation_1
1153+
assert res[1].id == derivation_2
1154+
1155+
add_response(DerivationType.circuit_rewiring)
1156+
res = client.get_entity_derivations(
1157+
entity_id=entity_id, entity_type=Circuit, derivation_type=DerivationType.circuit_rewiring
1158+
).all()
1159+
assert len(res) == 2
1160+
assert res[0].id == derivation_1
1161+
assert res[1].id == derivation_2
1162+
1163+
add_response(DerivationType.unspecified)
1164+
res = client.get_entity_derivations(
1165+
entity_id=entity_id, entity_type=Circuit, derivation_type=DerivationType.unspecified
11481166
).all()
11491167
assert len(res) == 2
11501168
assert res[0].id == derivation_1

0 commit comments

Comments
 (0)