Skip to content

Commit 4b4f5f1

Browse files
committed
Fix deployment tests
1 parent 3b8e629 commit 4b4f5f1

File tree

5 files changed

+59
-66
lines changed

5 files changed

+59
-66
lines changed

datajunction-server/datajunction_server/database/column.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,6 @@ def copy(self) -> "Column":
164164
"""
165165
Returns a full copy of the column
166166
"""
167-
print("self.type", self.name, self.type)
168167
return Column(
169168
order=self.order,
170169
name=self.name,

datajunction-server/datajunction_server/internal/deployment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ async def check_external_deps(
455455
if len(external_node_deps) != len(deps_not_in_deployment):
456456
missing_nodes = sorted(
457457
set(deps_not_in_deployment)
458-
- {node.name for node in external_node_deps},
458+
- {node.rendered_name for node in external_node_deps},
459459
)
460460
raise DJInvalidDeploymentConfig(
461461
message=(

datajunction-server/datajunction_server/models/deployment.py

Lines changed: 57 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
from enum import Enum
2-
from pydantic import BaseModel, Field, field_validator, PrivateAttr, ConfigDict
2+
from pydantic import (
3+
BaseModel,
4+
Field,
5+
field_validator,
6+
PrivateAttr,
7+
ConfigDict,
8+
model_validator,
9+
)
310

4-
from typing import Any, Literal, Union
11+
from typing import Annotated, Any, Literal, Union
512

613
from datajunction_server.models.partition import Granularity, PartitionType
714
from datajunction_server.errors import DJInvalidInputException
@@ -40,8 +47,8 @@ class PartitionSpec(BaseModel):
4047
"""
4148

4249
type: PartitionType
43-
granularity: Granularity | None
44-
format: str | None
50+
granularity: Granularity | None = None
51+
format: str | None = None
4552

4653

4754
class ColumnSpec(BaseModel):
@@ -91,7 +98,7 @@ class DimensionJoinLinkSpec(DimensionLinkSpec):
9198
"""
9299

93100
dimension_node: str
94-
type: LinkType = LinkType.JOIN
101+
type: Literal[LinkType.JOIN] = LinkType.JOIN
95102

96103
node_column: str | None = None
97104
join_type: JoinType = JoinType.LEFT
@@ -147,7 +154,7 @@ class DimensionReferenceLinkSpec(DimensionLinkSpec):
147154

148155
node_column: str
149156
dimension: str
150-
type: LinkType = LinkType.REFERENCE
157+
type: Literal[LinkType.REFERENCE] = LinkType.REFERENCE
151158

152159
@property
153160
def rendered_dimension_node(self) -> str:
@@ -262,31 +269,23 @@ class LinkableNodeSpec(NodeSpec):
262269
"""
263270

264271
columns: list[ColumnSpec] | None = None
265-
dimension_links: list[DimensionJoinLinkSpec | DimensionReferenceLinkSpec] = Field(
266-
default_factory=list,
267-
)
272+
dimension_links: list[
273+
Annotated[
274+
DimensionJoinLinkSpec | DimensionReferenceLinkSpec,
275+
Field(discriminator="type"),
276+
]
277+
] = Field(default_factory=list)
268278
primary_key: list[str] = Field(default_factory=list)
269279

270-
@field_validator("dimension_links", mode="before")
271-
def coerce_dimension_links(cls, values):
272-
if not isinstance(values, list):
273-
values = [values]
274-
result = []
275-
for value in values:
276-
if isinstance(value, dict):
277-
link_type = value.get("type")
278-
mapping = {
279-
"join": DimensionJoinLinkSpec,
280-
"reference": DimensionReferenceLinkSpec,
281-
}
282-
if link_type not in mapping: # pragma: no cover
283-
raise ValueError(f"Unknown link type: {link_type}")
284-
deployment_ns = getattr(cls, "namespace", None)
285-
result.append(mapping[link_type](**value, namespace=deployment_ns))
286-
else:
287-
value.namespace = getattr(cls, "namespace", None)
288-
result.append(value)
289-
return result
280+
@model_validator(mode="after")
281+
def set_namespaces(self):
282+
"""
283+
Set namespace on all dimension links
284+
"""
285+
if self.namespace:
286+
for link in self.dimension_links:
287+
link.namespace = self.namespace
288+
return self
290289

291290
def __eq__(self, other: Any) -> bool:
292291
if not isinstance(other, LinkableNodeSpec):
@@ -365,11 +364,11 @@ class MetricSpec(NodeSpec):
365364
query: str
366365
required_dimensions: list[str] | None = None
367366
direction: MetricDirection | None = None
368-
unit_enum: MetricUnit | None = Field(None, exclude=True)
367+
unit_enum: MetricUnit | None = Field(default=None, exclude=True)
369368

370369
significant_digits: int | None = None
371-
min_decimal_exponent: int | None
372-
max_decimal_exponent: int | None
370+
min_decimal_exponent: int | None = None
371+
max_decimal_exponent: int | None = None
373372

374373
def __init__(self, **data: Any):
375374
unit = data.pop("unit", None)
@@ -445,12 +444,15 @@ def __eq__(self, other: Any) -> bool:
445444
)
446445

447446

448-
NodeUnion = Union[
449-
SourceSpec,
450-
TransformSpec,
451-
DimensionSpec,
452-
MetricSpec,
453-
CubeSpec,
447+
NodeUnion = Annotated[
448+
Union[
449+
SourceSpec,
450+
TransformSpec,
451+
DimensionSpec,
452+
MetricSpec,
453+
CubeSpec,
454+
],
455+
Field(discriminator="node_type"),
454456
]
455457

456458

@@ -488,28 +490,23 @@ class DeploymentSpec(BaseModel):
488490
nodes: list[NodeUnion] = Field(default_factory=list)
489491
tags: list[TagSpec] = Field(default_factory=list)
490492

491-
@field_validator("nodes", mode="before")
492-
def coerce_nodes(cls, values):
493-
if not isinstance(values, list):
494-
values = [values]
495-
result = []
496-
for value in values:
497-
if isinstance(value, dict):
498-
node_type = value.get("node_type")
499-
mapping = {
500-
"source": SourceSpec,
501-
"transform": TransformSpec,
502-
"dimension": DimensionSpec,
503-
"metric": MetricSpec,
504-
"cube": CubeSpec,
505-
}
506-
if node_type not in mapping: # pragma: no cover
507-
raise ValueError(f"Unknown node_type: {node_type}")
508-
deployment_ns = getattr(cls, "namespace", None)
509-
result.append(mapping[node_type](**value, namespace=deployment_ns))
510-
else:
511-
result.append(value)
512-
return result
493+
@model_validator(mode="after")
494+
def set_namespaces(self):
495+
"""
496+
Set namespace on all node specs and their dimension links
497+
"""
498+
if hasattr(self, "nodes") and hasattr(self, "namespace") and self.namespace:
499+
for node in self.nodes:
500+
# Set namespace on the node itself
501+
if hasattr(node, "namespace") and not node.namespace:
502+
node.namespace = self.namespace
503+
504+
# Set namespace on dimension links (for LinkableNodeSpec subclasses)
505+
if hasattr(node, "dimension_links") and node.dimension_links:
506+
for link in node.dimension_links:
507+
if not link.namespace:
508+
link.namespace = self.namespace
509+
return self
513510

514511

515512
class VersionedNode(BaseModel):

datajunction-server/tests/api/data_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -995,7 +995,6 @@ async def test_raising_when_node_does_not_exist(
995995
},
996996
)
997997
data = response.json()
998-
print("response!!", data)
999998

1000999
assert response.status_code == 404
10011000
assert data == {

datajunction-server/tests/api/nodes_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1571,8 +1571,7 @@ async def test_register_view_with_query_service(
15711571
"""
15721572
Registering a view with a query service set up should succeed.
15731573
"""
1574-
res = await module__client_with_basic.get("/catalogs")
1575-
print("catalogs!", res.json())
1574+
await module__client_with_basic.get("/catalogs")
15761575
response = await module__client_with_basic.post(
15771576
"/register/view/public/main/view_foo?query=SELECT+1+AS+one+,+'two'+AS+two",
15781577
)
@@ -2072,7 +2071,6 @@ async def test_create_update_source_node(
20722071
},
20732072
)
20742073
data = response.json()
2075-
print("data!!", data)
20762074

20772075
assert data["name"] == "basic.source.comments"
20782076
assert data["display_name"] == "Comments facts"

0 commit comments

Comments
 (0)