|
1 | 1 | from enum import Enum |
2 | | -from pydantic import BaseModel, Field, field_validator, PrivateAttr, ConfigDict |
| 2 | +from pydantic import ( |
| 3 | + BaseModel, |
| 4 | + Field, |
| 5 | + field_validator, |
| 6 | + PrivateAttr, |
| 7 | + ConfigDict, |
| 8 | + model_validator, |
| 9 | +) |
3 | 10 |
|
4 | | -from typing import Any, Literal, Union |
| 11 | +from typing import Annotated, Any, Literal, Union |
5 | 12 |
|
6 | 13 | from datajunction_server.models.partition import Granularity, PartitionType |
7 | 14 | from datajunction_server.errors import DJInvalidInputException |
@@ -40,8 +47,8 @@ class PartitionSpec(BaseModel): |
40 | 47 | """ |
41 | 48 |
|
42 | 49 | type: PartitionType |
43 | | - granularity: Granularity | None |
44 | | - format: str | None |
| 50 | + granularity: Granularity | None = None |
| 51 | + format: str | None = None |
45 | 52 |
|
46 | 53 |
|
47 | 54 | class ColumnSpec(BaseModel): |
@@ -91,7 +98,7 @@ class DimensionJoinLinkSpec(DimensionLinkSpec): |
91 | 98 | """ |
92 | 99 |
|
93 | 100 | dimension_node: str |
94 | | - type: LinkType = LinkType.JOIN |
| 101 | + type: Literal[LinkType.JOIN] = LinkType.JOIN |
95 | 102 |
|
96 | 103 | node_column: str | None = None |
97 | 104 | join_type: JoinType = JoinType.LEFT |
@@ -147,7 +154,7 @@ class DimensionReferenceLinkSpec(DimensionLinkSpec): |
147 | 154 |
|
148 | 155 | node_column: str |
149 | 156 | dimension: str |
150 | | - type: LinkType = LinkType.REFERENCE |
| 157 | + type: Literal[LinkType.REFERENCE] = LinkType.REFERENCE |
151 | 158 |
|
152 | 159 | @property |
153 | 160 | def rendered_dimension_node(self) -> str: |
@@ -262,31 +269,23 @@ class LinkableNodeSpec(NodeSpec): |
262 | 269 | """ |
263 | 270 |
|
264 | 271 | columns: list[ColumnSpec] | None = None |
265 | | - dimension_links: list[DimensionJoinLinkSpec | DimensionReferenceLinkSpec] = Field( |
266 | | - default_factory=list, |
267 | | - ) |
| 272 | + dimension_links: list[ |
| 273 | + Annotated[ |
| 274 | + DimensionJoinLinkSpec | DimensionReferenceLinkSpec, |
| 275 | + Field(discriminator="type"), |
| 276 | + ] |
| 277 | + ] = Field(default_factory=list) |
268 | 278 | primary_key: list[str] = Field(default_factory=list) |
269 | 279 |
|
270 | | - @field_validator("dimension_links", mode="before") |
271 | | - def coerce_dimension_links(cls, values): |
272 | | - if not isinstance(values, list): |
273 | | - values = [values] |
274 | | - result = [] |
275 | | - for value in values: |
276 | | - if isinstance(value, dict): |
277 | | - link_type = value.get("type") |
278 | | - mapping = { |
279 | | - "join": DimensionJoinLinkSpec, |
280 | | - "reference": DimensionReferenceLinkSpec, |
281 | | - } |
282 | | - if link_type not in mapping: # pragma: no cover |
283 | | - raise ValueError(f"Unknown link type: {link_type}") |
284 | | - deployment_ns = getattr(cls, "namespace", None) |
285 | | - result.append(mapping[link_type](**value, namespace=deployment_ns)) |
286 | | - else: |
287 | | - value.namespace = getattr(cls, "namespace", None) |
288 | | - result.append(value) |
289 | | - return result |
| 280 | + @model_validator(mode="after") |
| 281 | + def set_namespaces(self): |
| 282 | + """ |
| 283 | + Set namespace on all dimension links |
| 284 | + """ |
| 285 | + if self.namespace: |
| 286 | + for link in self.dimension_links: |
| 287 | + link.namespace = self.namespace |
| 288 | + return self |
290 | 289 |
|
291 | 290 | def __eq__(self, other: Any) -> bool: |
292 | 291 | if not isinstance(other, LinkableNodeSpec): |
@@ -365,11 +364,11 @@ class MetricSpec(NodeSpec): |
365 | 364 | query: str |
366 | 365 | required_dimensions: list[str] | None = None |
367 | 366 | direction: MetricDirection | None = None |
368 | | - unit_enum: MetricUnit | None = Field(None, exclude=True) |
| 367 | + unit_enum: MetricUnit | None = Field(default=None, exclude=True) |
369 | 368 |
|
370 | 369 | significant_digits: int | None = None |
371 | | - min_decimal_exponent: int | None |
372 | | - max_decimal_exponent: int | None |
| 370 | + min_decimal_exponent: int | None = None |
| 371 | + max_decimal_exponent: int | None = None |
373 | 372 |
|
374 | 373 | def __init__(self, **data: Any): |
375 | 374 | unit = data.pop("unit", None) |
@@ -445,12 +444,15 @@ def __eq__(self, other: Any) -> bool: |
445 | 444 | ) |
446 | 445 |
|
447 | 446 |
|
448 | | -NodeUnion = Union[ |
449 | | - SourceSpec, |
450 | | - TransformSpec, |
451 | | - DimensionSpec, |
452 | | - MetricSpec, |
453 | | - CubeSpec, |
| 447 | +NodeUnion = Annotated[ |
| 448 | + Union[ |
| 449 | + SourceSpec, |
| 450 | + TransformSpec, |
| 451 | + DimensionSpec, |
| 452 | + MetricSpec, |
| 453 | + CubeSpec, |
| 454 | + ], |
| 455 | + Field(discriminator="node_type"), |
454 | 456 | ] |
455 | 457 |
|
456 | 458 |
|
@@ -488,28 +490,23 @@ class DeploymentSpec(BaseModel): |
488 | 490 | nodes: list[NodeUnion] = Field(default_factory=list) |
489 | 491 | tags: list[TagSpec] = Field(default_factory=list) |
490 | 492 |
|
491 | | - @field_validator("nodes", mode="before") |
492 | | - def coerce_nodes(cls, values): |
493 | | - if not isinstance(values, list): |
494 | | - values = [values] |
495 | | - result = [] |
496 | | - for value in values: |
497 | | - if isinstance(value, dict): |
498 | | - node_type = value.get("node_type") |
499 | | - mapping = { |
500 | | - "source": SourceSpec, |
501 | | - "transform": TransformSpec, |
502 | | - "dimension": DimensionSpec, |
503 | | - "metric": MetricSpec, |
504 | | - "cube": CubeSpec, |
505 | | - } |
506 | | - if node_type not in mapping: # pragma: no cover |
507 | | - raise ValueError(f"Unknown node_type: {node_type}") |
508 | | - deployment_ns = getattr(cls, "namespace", None) |
509 | | - result.append(mapping[node_type](**value, namespace=deployment_ns)) |
510 | | - else: |
511 | | - result.append(value) |
512 | | - return result |
| 493 | + @model_validator(mode="after") |
| 494 | + def set_namespaces(self): |
| 495 | + """ |
| 496 | + Set namespace on all node specs and their dimension links |
| 497 | + """ |
| 498 | + if hasattr(self, "nodes") and hasattr(self, "namespace") and self.namespace: |
| 499 | + for node in self.nodes: |
| 500 | + # Set namespace on the node itself |
| 501 | + if hasattr(node, "namespace") and not node.namespace: |
| 502 | + node.namespace = self.namespace |
| 503 | + |
| 504 | + # Set namespace on dimension links (for LinkableNodeSpec subclasses) |
| 505 | + if hasattr(node, "dimension_links") and node.dimension_links: |
| 506 | + for link in node.dimension_links: |
| 507 | + if not link.namespace: |
| 508 | + link.namespace = self.namespace |
| 509 | + return self |
513 | 510 |
|
514 | 511 |
|
515 | 512 | class VersionedNode(BaseModel): |
|
0 commit comments