Skip to content

Commit a9ce05f

Browse files
committed
Migrate datajunction-server to pydantic v2
1 parent eef7f00 commit a9ce05f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+2045
-1804
lines changed

datajunction-server/datajunction_server/api/access/authentication/whoami.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ async def whoami(
3131
Returns the current authenticated user
3232
"""
3333
user = await User.get_by_username(session, current_user.username)
34-
return UserOutput.from_orm(user)
34+
return UserOutput.model_validate(user, from_attributes=True)
3535

3636

3737
@router.get("/token/")

datajunction-server/datajunction_server/api/attributes.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@ async def list_attributes(
3535
List all available attribute types.
3636
"""
3737
attributes = await AttributeType.get_all(session)
38-
return [AttributeTypeBase.from_orm(attr) for attr in attributes]
38+
return [
39+
AttributeTypeBase.model_validate(attr, from_attributes=True)
40+
for attr in attributes
41+
]
3942

4043

4144
@router.post(
@@ -62,7 +65,7 @@ async def add_attribute_type(
6265
message=f"Attribute type `{data.name}` already exists!",
6366
)
6467
attribute_type = await AttributeType.create(session, data)
65-
return AttributeTypeBase.from_orm(attribute_type)
68+
return AttributeTypeBase.model_validate(attribute_type, from_attributes=True)
6669

6770

6871
async def default_attribute_types(session: AsyncSession = Depends(get_session)):

datajunction-server/datajunction_server/api/catalogs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ async def list_catalogs(
3636
"""
3737
statement = select(Catalog).options(joinedload(Catalog.engines))
3838
return [
39-
CatalogInfo.from_orm(catalog)
39+
CatalogInfo.model_validate(catalog, from_attributes=True)
4040
for catalog in (await session.execute(statement)).unique().scalars()
4141
]
4242

@@ -111,7 +111,7 @@ async def add_catalog(
111111
await session.commit()
112112
await session.refresh(catalog, ["engines"])
113113

114-
return CatalogInfo.from_orm(catalog)
114+
return CatalogInfo.model_validate(catalog, from_attributes=True)
115115

116116

117117
@router.post(
@@ -136,7 +136,7 @@ async def add_engines_to_catalog(
136136
session.add(catalog)
137137
await session.commit()
138138
await session.refresh(catalog)
139-
return CatalogInfo.from_orm(catalog)
139+
return CatalogInfo.model_validate(catalog, from_attributes=True)
140140

141141

142142
async def list_new_engines(

datajunction-server/datajunction_server/api/collection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ async def create_a_collection(
6161
await session.commit()
6262
await session.refresh(collection)
6363

64-
return CollectionInfo.from_orm(collection)
64+
return CollectionInfo.model_validate(collection, from_attributes=True)
6565

6666

6767
@router.delete("/collections/{name}", status_code=HTTPStatus.NO_CONTENT)

datajunction-server/datajunction_server/api/cubes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ async def cube_materialization_info(
159159
Granularity.YEAR: "0 0 1 1 *", # Runs at midnight on January 1st every year
160160
}
161161
upsert = UpsertCubeMaterialization(
162-
job=MaterializationJobTypeEnum.DRUID_CUBE,
162+
job=MaterializationJobTypeEnum.DRUID_CUBE.value.name,
163163
strategy=(
164164
MaterializationStrategy.INCREMENTAL_TIME
165165
if temporal_partition
@@ -186,7 +186,7 @@ async def cube_materialization_info(
186186
metrics=cube_config.metrics,
187187
strategy=upsert.strategy,
188188
schedule=upsert.schedule,
189-
job=upsert.job.name,
189+
job=upsert.job.name, # type: ignore
190190
measures_materializations=cube_config.measures_materializations,
191191
combiners=cube_config.combiners,
192192
)
@@ -299,7 +299,7 @@ async def get_cube_dimension_values(
299299
value=row[0 : count_column[0]] if count_column else row,
300300
count=row[count_column[0]] if count_column else None,
301301
)
302-
for row in result.results.__root__[0].rows
302+
for row in result.results.root[0].rows
303303
]
304304
return DimensionValues( # pragma: no cover
305305
dimensions=[

datajunction-server/datajunction_server/api/data.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,15 @@ async def add_availability_state(
141141
table=data.table,
142142
valid_through_ts=data.valid_through_ts,
143143
url=data.url,
144-
min_temporal_partition=data.min_temporal_partition,
145-
max_temporal_partition=data.max_temporal_partition,
144+
min_temporal_partition=[
145+
str(part) for part in data.min_temporal_partition or []
146+
],
147+
max_temporal_partition=[
148+
str(part) for part in data.max_temporal_partition or []
149+
],
146150
partitions=[
147-
partition.dict() if not isinstance(partition, Dict) else partition
148-
for partition in data.partitions # type: ignore
151+
partition.model_dump() if not isinstance(partition, Dict) else partition
152+
for partition in (data.partitions or [])
149153
],
150154
categorical_partitions=data.categorical_partitions,
151155
temporal_partitions=data.temporal_partitions,
@@ -159,10 +163,14 @@ async def add_availability_state(
159163
entity_type=EntityType.AVAILABILITY,
160164
node=node.name, # type: ignore
161165
activity_type=ActivityType.CREATE,
162-
pre=AvailabilityStateBase.from_orm(old_availability).dict()
166+
pre=AvailabilityStateBase.model_validate(
167+
old_availability,
168+
).model_dump()
163169
if old_availability
164170
else {},
165-
post=AvailabilityStateBase.from_orm(node_revision.availability).dict(),
171+
post=AvailabilityStateBase.model_validate(
172+
node_revision.availability,
173+
).model_dump(),
166174
user=current_user.username,
167175
),
168176
session=session,
@@ -262,8 +270,8 @@ async def get_data(
262270
)
263271

264272
# Inject column info if there are results
265-
if result.results.__root__: # pragma: no cover
266-
result.results.__root__[0].columns = generated_sql.columns # type: ignore
273+
if result.results.root: # pragma: no cover
274+
result.results.root[0].columns = generated_sql.columns # type: ignore
267275
return result
268276

269277

@@ -447,8 +455,8 @@ async def get_data_for_metrics(
447455
)
448456

449457
# Inject column info if there are results
450-
if result.results.__root__: # pragma: no cover
451-
result.results.__root__[0].columns = translated_sql.columns or []
458+
if result.results.root: # pragma: no cover
459+
result.results.root[0].columns = translated_sql.columns or []
452460
return result
453461

454462

datajunction-server/datajunction_server/api/deployments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ async def submit(self, spec: DeploymentSpec, context: DeploymentContext) -> str:
7777
deployment = Deployment(
7878
uuid=deployment_uuid,
7979
namespace=spec.namespace,
80-
spec=spec.dict(),
80+
spec=spec.model_dump(),
8181
status=DeploymentStatus.PENDING,
8282
created_by_id=context.current_user.id,
8383
)
@@ -109,7 +109,7 @@ async def update_status(
109109
deployment = await session.get(Deployment, deployment_uuid)
110110
deployment.status = status
111111
if results is not None:
112-
deployment.results = [r.dict() for r in results]
112+
deployment.results = [r.model_dump() for r in results]
113113
await session.commit()
114114

115115
async def _run_deployment(

datajunction-server/datajunction_server/api/djsql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ async def get_data_for_djsql(
7272
)
7373

7474
# Inject column info if there are results
75-
if result.results.__root__: # pragma: no cover
76-
result.results.__root__[0].columns = translated_sql.columns or []
75+
if result.results.root: # pragma: no cover
76+
result.results.root[0].columns = translated_sql.columns or []
7777
return result
7878

7979

datajunction-server/datajunction_server/api/engines.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ async def list_engines(
4343
List all available engines
4444
"""
4545
return [
46-
EngineInfo.from_orm(engine)
46+
EngineInfo.model_validate(engine)
4747
for engine in (await session.execute(select(Engine))).scalars()
4848
]
4949

@@ -58,7 +58,9 @@ async def get_an_engine(
5858
"""
5959
Return an engine by name and version
6060
"""
61-
return EngineInfo.from_orm(await get_engine(session, name, version))
61+
return EngineInfo.model_validate(
62+
await get_engine(session, name, version),
63+
)
6264

6365

6466
@router.post(
@@ -95,4 +97,4 @@ async def add_engine(
9597
await session.commit()
9698
await session.refresh(engine)
9799

98-
return EngineInfo.from_orm(engine)
100+
return EngineInfo.model_validate(engine)

datajunction-server/datajunction_server/api/helpers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -681,8 +681,8 @@ async def query_event_stream(
681681
"query end state detected (%s), sending final event to the client",
682682
query_next.state,
683683
)
684-
if query_next.results.__root__: # pragma: no cover
685-
query_next.results.__root__[0].columns = columns or []
684+
if query_next.results.root: # pragma: no cover
685+
query_next.results.root[0].columns = columns or []
686686
yield {
687687
"event": "message",
688688
"id": uuid.uuid4(),
@@ -879,13 +879,13 @@ def get_node_revision_materialization(
879879
)
880880
if materialization.strategy != MaterializationStrategy.INCREMENTAL_TIME:
881881
info.urls = [info.urls[0]]
882-
materialization_config_output = MaterializationConfigOutput.from_orm(
882+
materialization_config_output = MaterializationConfigOutput.model_validate(
883883
materialization,
884884
)
885885
materializations.append(
886886
MaterializationConfigInfoUnified(
887-
**materialization_config_output.dict(),
888-
**info.dict(),
887+
**materialization_config_output.model_dump(),
888+
**info.model_dump(),
889889
),
890890
)
891891
return materializations

0 commit comments

Comments
 (0)