From bb5f1dfe1b13a54b083c7b78b4a5c0ae8c53d095 Mon Sep 17 00:00:00 2001 From: Donny Winston Date: Wed, 1 Nov 2023 11:29:03 -0400 Subject: [PATCH] migrate to pydantic v2 (#344) * fix: run `bump-pydantic nmdc_runtime` and apply closes #339 addresses #343 * fix: @model_validator refactor closes #343 --- nmdc_runtime/api/core/auth.py | 2 +- nmdc_runtime/api/models/capability.py | 2 +- nmdc_runtime/api/models/id.py | 21 +++--- nmdc_runtime/api/models/job.py | 4 +- nmdc_runtime/api/models/metadata.py | 5 +- nmdc_runtime/api/models/object.py | 75 ++++++++++--------- nmdc_runtime/api/models/object_type.py | 4 +- nmdc_runtime/api/models/operation.py | 26 +++---- nmdc_runtime/api/models/query.py | 45 +++++------ nmdc_runtime/api/models/run.py | 2 +- nmdc_runtime/api/models/util.py | 28 +++---- nmdc_runtime/api/models/workflow.py | 6 +- nmdc_runtime/api/v1/models/ingest.py | 6 +- .../database/impl/mongo/models/user.py | 8 +- nmdc_runtime/minter/domain/model.py | 4 +- nmdc_runtime/util.py | 5 +- 16 files changed, 125 insertions(+), 118 deletions(-) diff --git a/nmdc_runtime/api/core/auth.py b/nmdc_runtime/api/core/auth.py index 3614fb55..fe026162 100644 --- a/nmdc_runtime/api/core/auth.py +++ b/nmdc_runtime/api/core/auth.py @@ -35,7 +35,7 @@ class TokenExpires(BaseModel): class Token(BaseModel): access_token: str token_type: str - expires: Optional[TokenExpires] + expires: Optional[TokenExpires] = None class TokenData(BaseModel): diff --git a/nmdc_runtime/api/models/capability.py b/nmdc_runtime/api/models/capability.py index 95241d79..6de4a2ff 100644 --- a/nmdc_runtime/api/models/capability.py +++ b/nmdc_runtime/api/models/capability.py @@ -5,7 +5,7 @@ class CapabilityBase(BaseModel): - name: Optional[str] + name: Optional[str] = None description: Optional[str] = None diff --git a/nmdc_runtime/api/models/id.py b/nmdc_runtime/api/models/id.py index 34064065..687109a2 100644 --- a/nmdc_runtime/api/models/id.py +++ b/nmdc_runtime/api/models/id.py @@ -2,7 +2,8 @@ from enum import Enum from typing import Union, Any, Optional, Literal -from pydantic import BaseModel, constr, PositiveInt, root_validator +from pydantic import model_validator, StringConstraints, BaseModel, PositiveInt +from typing_extensions import Annotated # NO i, l, o or u. base32_letters = "abcdefghjkmnpqrstvwxyz" @@ -22,11 +23,11 @@ _base_object_name = f"{_naa}:{_shoulder}{_blade}" pattern_base_object_name = re.compile(_base_object_name) -Naa = constr(pattern=_naa) -Shoulder = constr(pattern=rf"^{_shoulder}$", min_length=2) -Blade = constr(pattern=_blade, min_length=4) -AssignedBaseName = constr(pattern=_assigned_base_name) -BaseObjectName = constr(pattern=_base_object_name) +Naa = Annotated[str, StringConstraints(pattern=_naa)] +Shoulder = Annotated[str, StringConstraints(pattern=rf"^{_shoulder}$", min_length=2)] +Blade = Annotated[str, StringConstraints(pattern=_blade, min_length=4)] +AssignedBaseName = Annotated[str, StringConstraints(pattern=_assigned_base_name)] +BaseObjectName = Annotated[str, StringConstraints(pattern=_base_object_name)] NameAssigningAuthority = Literal[tuple(NAA_VALUES)] @@ -71,10 +72,10 @@ class IdBindingOp(str, Enum): class IdBindingRequest(BaseModel): i: BaseObjectName o: IdBindingOp = IdBindingOp.set - a: Optional[str] - v: Any + a: Optional[str] = None + v: Any = None - @root_validator(skip_on_failure=True) + @model_validator(mode="before") def set_or_add_needs_value(cls, values): op = values.get("o") if op in (IdBindingOp.set, IdBindingOp.addToSet): @@ -82,7 +83,7 @@ def set_or_add_needs_value(cls, values): raise ValueError("{'set','add'} operations needs value 'v'.") return values - @root_validator(skip_on_failure=True) + @model_validator(mode="before") def set_or_add_or_rm_needs_attribute(cls, values): op = values.get("o") if op in (IdBindingOp.set, IdBindingOp.addToSet, IdBindingOp.rm): diff --git a/nmdc_runtime/api/models/job.py b/nmdc_runtime/api/models/job.py index 2c9fae63..e39c2877 100644 --- a/nmdc_runtime/api/models/job.py +++ b/nmdc_runtime/api/models/job.py @@ -9,7 +9,7 @@ class JobBase(BaseModel): workflow: Workflow - name: Optional[str] + name: Optional[str] = None description: Optional[str] = None @@ -20,7 +20,7 @@ class JobClaim(BaseModel): class Job(JobBase): id: str - created_at: Optional[datetime.datetime] + created_at: Optional[datetime.datetime] = None config: Dict[str, Any] claims: List[JobClaim] = [] diff --git a/nmdc_runtime/api/models/metadata.py b/nmdc_runtime/api/models/metadata.py index 7a5aabaa..5e70cba5 100644 --- a/nmdc_runtime/api/models/metadata.py +++ b/nmdc_runtime/api/models/metadata.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Extra +from pydantic import ConfigDict, BaseModel class ChangesheetIn(BaseModel): @@ -8,5 +8,4 @@ class ChangesheetIn(BaseModel): class Doc(BaseModel): - class Config: - extra = Extra.allow + model_config = ConfigDict(extra="allow") diff --git a/nmdc_runtime/api/models/object.py b/nmdc_runtime/api/models/object.py index 65281362..26af100c 100644 --- a/nmdc_runtime/api/models/object.py +++ b/nmdc_runtime/api/models/object.py @@ -5,14 +5,15 @@ from typing import Optional, List, Dict from pydantic import ( + field_validator, + model_validator, + Field, + StringConstraints, BaseModel, AnyUrl, - constr, - conint, HttpUrl, - root_validator, - validator, ) +from typing_extensions import Annotated class AccessMethodType(str, Enum): @@ -27,17 +28,17 @@ class AccessMethodType(str, Enum): class AccessURL(BaseModel): - headers: Optional[Dict[str, str]] + headers: Optional[Dict[str, str]] = None url: AnyUrl class AccessMethod(BaseModel): - access_id: Optional[constr(min_length=1)] - access_url: Optional[AccessURL] - region: Optional[str] + access_id: Optional[Annotated[str, StringConstraints(min_length=1)]] = None + access_url: Optional[AccessURL] = None + region: Optional[str] = None type: AccessMethodType = AccessMethodType.https - @root_validator(skip_on_failure=True) + @model_validator(mode="before") def at_least_one_of_access_id_and_url(cls, values): access_id, access_url = values.get("access_id"), values.get("access_url") if access_id is None and access_url is None: @@ -47,27 +48,30 @@ def at_least_one_of_access_id_and_url(cls, values): return values -ChecksumType = constr( - pattern=rf"(?P({'|'.join(sorted(hashlib.algorithms_guaranteed))}))" -) +ChecksumType = Annotated[ + str, + StringConstraints( + pattern=rf"(?P({'|'.join(sorted(hashlib.algorithms_guaranteed))}))" + ), +] class Checksum(BaseModel): - checksum: constr(min_length=1) + checksum: Annotated[str, StringConstraints(min_length=1)] type: ChecksumType -DrsId = constr(pattern=r"^[A-Za-z0-9._~\-]+$") -PortableFilename = constr(pattern=r"^[A-Za-z0-9._\-]+$") +DrsId = Annotated[str, StringConstraints(pattern=r"^[A-Za-z0-9._~\-]+$")] +PortableFilename = Annotated[str, StringConstraints(pattern=r"^[A-Za-z0-9._\-]+$")] class ContentsObject(BaseModel): - contents: Optional[List["ContentsObject"]] - drs_uri: Optional[List[AnyUrl]] - id: Optional[DrsId] + contents: Optional[List["ContentsObject"]] = None + drs_uri: Optional[List[AnyUrl]] = None + id: Optional[DrsId] = None name: PortableFilename - @root_validator(skip_on_failure=True) + @model_validator(mode="before") def no_contents_means_single_blob(cls, values): contents, id_ = values.get("contents"), values.get("id") if contents is None and id_ is None: @@ -77,32 +81,32 @@ def no_contents_means_single_blob(cls, values): ContentsObject.update_forward_refs() -Mimetype = constr(pattern=r"^\w+/[-+.\w]+$") -SizeInBytes = conint(strict=True, ge=0) +Mimetype = Annotated[str, StringConstraints(pattern=r"^\w+/[-+.\w]+$")] +SizeInBytes = Annotated[int, Field(strict=True, ge=0)] class Error(BaseModel): - msg: Optional[str] + msg: Optional[str] = None status_code: http.HTTPStatus class DrsObjectBase(BaseModel): - aliases: Optional[List[str]] + aliases: Optional[List[str]] = None description: Optional[str] = None - mime_type: Optional[Mimetype] - name: Optional[PortableFilename] + mime_type: Optional[Mimetype] = None + name: Optional[PortableFilename] = None class DrsObjectIn(DrsObjectBase): - access_methods: Optional[List[AccessMethod]] + access_methods: Optional[List[AccessMethod]] = None checksums: List[Checksum] - contents: Optional[List[ContentsObject]] + contents: Optional[List[ContentsObject]] = None created_time: datetime.datetime size: SizeInBytes - updated_time: Optional[datetime.datetime] - version: Optional[str] + updated_time: Optional[datetime.datetime] = None + version: Optional[str] = None - @root_validator(skip_on_failure=True) + @model_validator(mode="before") def no_contents_means_single_blob(cls, values): contents, access_methods = values.get("contents"), values.get("access_methods") if contents is None and access_methods is None: @@ -111,7 +115,8 @@ def no_contents_means_single_blob(cls, values): ) return values - @validator("checksums") + @field_validator("checksums") + @classmethod def at_least_one_checksum(cls, v): if not len(v) >= 1: raise ValueError("At least one checksum requried") @@ -123,7 +128,7 @@ class DrsObject(DrsObjectIn): self_uri: AnyUrl -Seconds = conint(strict=True, gt=0) +Seconds = Annotated[int, Field(strict=True, gt=0)] class ObjectPresignedUrl(BaseModel): @@ -137,8 +142,8 @@ class DrsObjectOutBase(DrsObjectBase): id: DrsId self_uri: AnyUrl size: SizeInBytes - updated_time: Optional[datetime.datetime] - version: Optional[str] + updated_time: Optional[datetime.datetime] = None + version: Optional[str] = None class DrsObjectBlobOut(DrsObjectOutBase): @@ -146,5 +151,5 @@ class DrsObjectBlobOut(DrsObjectOutBase): class DrsObjectBundleOut(DrsObjectOutBase): - access_methods: Optional[List[AccessMethod]] + access_methods: Optional[List[AccessMethod]] = None contents: List[ContentsObject] diff --git a/nmdc_runtime/api/models/object_type.py b/nmdc_runtime/api/models/object_type.py index b1c54cfa..fca6262a 100644 --- a/nmdc_runtime/api/models/object_type.py +++ b/nmdc_runtime/api/models/object_type.py @@ -7,7 +7,7 @@ class ObjectTypeBase(BaseModel): - name: Optional[str] + name: Optional[str] = None description: Optional[str] = None @@ -17,4 +17,4 @@ class ObjectType(ObjectTypeBase): class DrsObjectWithTypes(DrsObject): - types: Optional[List[str]] + types: Optional[List[str]] = None diff --git a/nmdc_runtime/api/models/operation.py b/nmdc_runtime/api/models/operation.py index 1e35df2a..e1819f24 100644 --- a/nmdc_runtime/api/models/operation.py +++ b/nmdc_runtime/api/models/operation.py @@ -1,44 +1,44 @@ import datetime from typing import Generic, TypeVar, Optional, List, Any, Union -from pydantic import BaseModel, HttpUrl, constr -from pydantic.generics import GenericModel +from pydantic import StringConstraints, BaseModel, HttpUrl from nmdc_runtime.api.models.util import ResultT +from typing_extensions import Annotated MetadataT = TypeVar("MetadataT") -PythonImportPath = constr(pattern=r"^[A-Za-z0-9_.]+$") +PythonImportPath = Annotated[str, StringConstraints(pattern=r"^[A-Za-z0-9_.]+$")] class OperationError(BaseModel): code: str message: str - details: Any + details: Any = None -class Operation(GenericModel, Generic[ResultT, MetadataT]): +class Operation(BaseModel, Generic[ResultT, MetadataT]): id: str done: bool = False expire_time: datetime.datetime - result: Optional[Union[ResultT, OperationError]] - metadata: Optional[MetadataT] + result: Optional[Union[ResultT, OperationError]] = None + metadata: Optional[MetadataT] = None -class UpdateOperationRequest(GenericModel, Generic[ResultT, MetadataT]): +class UpdateOperationRequest(BaseModel, Generic[ResultT, MetadataT]): done: bool = False - result: Optional[Union[ResultT, OperationError]] + result: Optional[Union[ResultT, OperationError]] = None metadata: Optional[MetadataT] = {} -class ListOperationsResponse(GenericModel, Generic[ResultT, MetadataT]): +class ListOperationsResponse(BaseModel, Generic[ResultT, MetadataT]): resources: List[Operation[ResultT, MetadataT]] - next_page_token: Optional[str] + next_page_token: Optional[str] = None class Result(BaseModel): - model: Optional[PythonImportPath] + model: Optional[PythonImportPath] = None class EmptyResult(Result): @@ -47,7 +47,7 @@ class EmptyResult(Result): class Metadata(BaseModel): # XXX alternative: set model field using __class__ on __init__()? - model: Optional[PythonImportPath] + model: Optional[PythonImportPath] = None class PausedOrNot(Metadata): diff --git a/nmdc_runtime/api/models/query.py b/nmdc_runtime/api/models/query.py index a99e54be..828f5a43 100644 --- a/nmdc_runtime/api/models/query.py +++ b/nmdc_runtime/api/models/query.py @@ -2,23 +2,24 @@ from typing import Optional, Any, Dict, List, Union from pydantic import ( + model_validator, + Field, BaseModel, - root_validator, - conint, PositiveInt, NonNegativeInt, ) +from typing_extensions import Annotated Document = Dict[str, Any] -OneOrZero = conint(ge=0, le=1) -One = conint(ge=1, le=1) -MinusOne = conint(ge=-1, le=-1) +OneOrZero = Annotated[int, Field(ge=0, le=1)] +One = Annotated[int, Field(ge=1, le=1)] +MinusOne = Annotated[int, Field(ge=-1, le=-1)] OneOrMinusOne = Union[One, MinusOne] class CommandBase(BaseModel): - comment: Optional[Any] + comment: Optional[Any] = None class CollStatsCommand(CommandBase): @@ -28,17 +29,17 @@ class CollStatsCommand(CommandBase): class CountCommand(CommandBase): count: str - query: Optional[Document] + query: Optional[Document] = None class FindCommand(CommandBase): find: str - filter: Optional[Document] - projection: Optional[Dict[str, OneOrZero]] + filter: Optional[Document] = None + projection: Optional[Dict[str, OneOrZero]] = None allowPartialResults: Optional[bool] = True batchSize: Optional[PositiveInt] = 101 - sort: Optional[Dict[str, OneOrMinusOne]] - limit: Optional[NonNegativeInt] + sort: Optional[Dict[str, OneOrMinusOne]] = None + limit: Optional[NonNegativeInt] = None class CommandResponse(BaseModel): @@ -49,7 +50,7 @@ class CollStatsCommandResponse(CommandResponse): ns: str size: float count: float - avgObjSize: Optional[float] + avgObjSize: Optional[float] = None storageSize: float totalIndexSize: float totalSize: float @@ -62,8 +63,8 @@ class CountCommandResponse(CommandResponse): class FindCommandResponseCursor(BaseModel): firstBatch: List[Document] - partialResultsReturned: Optional[bool] - id: Optional[int] + partialResultsReturned: Optional[bool] = None + id: Optional[int] = None ns: str @@ -74,7 +75,7 @@ class FindCommandResponse(CommandResponse): class DeleteCommandDelete(BaseModel): q: Document limit: OneOrZero - hint: Optional[Dict[str, OneOrMinusOne]] + hint: Optional[Dict[str, OneOrMinusOne]] = None class DeleteCommand(CommandBase): @@ -85,19 +86,19 @@ class DeleteCommand(CommandBase): class DeleteCommandResponse(CommandResponse): ok: OneOrZero n: NonNegativeInt - writeErrors: Optional[List[Document]] + writeErrors: Optional[List[Document]] = None class GetMoreCommand(CommandBase): getMore: int collection: str - batchSize: Optional[PositiveInt] + batchSize: Optional[PositiveInt] = None class GetMoreCommandResponseCursor(BaseModel): nextBatch: List[Document] - partialResultsReturned: Optional[bool] - id: Optional[int] + partialResultsReturned: Optional[bool] = None + id: Optional[int] = None ns: str @@ -138,10 +139,10 @@ class Query(BaseModel): class QueryRun(BaseModel): qid: str ran_at: datetime.datetime - result: Optional[Any] - error: Optional[Any] + result: Optional[Any] = None + error: Optional[Any] = None - @root_validator(skip_on_failure=True) + @model_validator(mode="before") def result_xor_error(cls, values): result, error = values.get("result"), values.get("error") if result is None and error is None: diff --git a/nmdc_runtime/api/models/run.py b/nmdc_runtime/api/models/run.py index 4f7cd760..49fa37ba 100644 --- a/nmdc_runtime/api/models/run.py +++ b/nmdc_runtime/api/models/run.py @@ -41,7 +41,7 @@ class JobSummary(OpenLineageBase): class Run(BaseModel): id: str - facets: Optional[dict] + facets: Optional[dict] = None class RunEventType(str, Enum): diff --git a/nmdc_runtime/api/models/util.py b/nmdc_runtime/api/models/util.py index 4a1413b9..ff922531 100644 --- a/nmdc_runtime/api/models/util.py +++ b/nmdc_runtime/api/models/util.py @@ -2,15 +2,15 @@ from fastapi import Query -from pydantic import BaseModel, root_validator, conint -from pydantic.generics import GenericModel +from pydantic import model_validator, Field, BaseModel +from typing_extensions import Annotated ResultT = TypeVar("ResultT") -class ListResponse(GenericModel, Generic[ResultT]): +class ListResponse(BaseModel, Generic[ResultT]): resources: List[ResultT] - next_page_token: Optional[str] + next_page_token: Optional[str] = None class ListRequest(BaseModel): @@ -21,7 +21,7 @@ class ListRequest(BaseModel): ), ] max_page_size: Optional[int] = 20 - page_token: Optional[str] + page_token: Optional[str] = None projection: Annotated[ Optional[str], Query( @@ -36,25 +36,25 @@ class ListRequest(BaseModel): ] -PerPageRange = conint(gt=0, le=2_000) +PerPageRange = Annotated[int, Field(gt=0, le=2_000)] class FindRequest(BaseModel): - filter: Optional[str] - search: Optional[str] - sort: Optional[str] - page: Optional[int] + filter: Optional[str] = None + search: Optional[str] = None + sort: Optional[str] = None + page: Optional[int] = None per_page: Optional[PerPageRange] = 25 - cursor: Optional[str] - group_by: Optional[str] + cursor: Optional[str] = None + group_by: Optional[str] = None fields: Annotated[ Optional[str], Query( description="comma-separated list of fields you want the objects in the response to include" ), - ] + ] = None - @root_validator(pre=True) + @model_validator(mode="before") def set_page_if_cursor_unset(cls, values): page, cursor = values.get("page"), values.get("cursor") if page is not None and cursor is not None: diff --git a/nmdc_runtime/api/models/workflow.py b/nmdc_runtime/api/models/workflow.py index 942b6f1e..a7740752 100644 --- a/nmdc_runtime/api/models/workflow.py +++ b/nmdc_runtime/api/models/workflow.py @@ -5,11 +5,11 @@ class WorkflowBase(BaseModel): - name: Optional[str] + name: Optional[str] = None description: Optional[str] = None - capability_ids: Optional[List[str]] + capability_ids: Optional[List[str]] = None class Workflow(WorkflowBase): id: str - created_at: Optional[datetime.datetime] + created_at: Optional[datetime.datetime] = None diff --git a/nmdc_runtime/api/v1/models/ingest.py b/nmdc_runtime/api/v1/models/ingest.py index 857383c2..a0e384f3 100644 --- a/nmdc_runtime/api/v1/models/ingest.py +++ b/nmdc_runtime/api/v1/models/ingest.py @@ -6,6 +6,6 @@ class Ingest(BaseModel): data_object_set: List[DataObject] = [] - read_qc_analysis_activity_set: Optional[List[ReadsQCSequencingActivity]] - metagenome_assembly_activity_set: Optional[List[ReadsQCSequencingActivity]] - metagenome_annotation_activity_set: Optional[List[ReadsQCSequencingActivity]] + read_qc_analysis_activity_set: Optional[List[ReadsQCSequencingActivity]] = None + metagenome_assembly_activity_set: Optional[List[ReadsQCSequencingActivity]] = None + metagenome_annotation_activity_set: Optional[List[ReadsQCSequencingActivity]] = None diff --git a/nmdc_runtime/infrastructure/database/impl/mongo/models/user.py b/nmdc_runtime/infrastructure/database/impl/mongo/models/user.py index d39a114a..586d4362 100644 --- a/nmdc_runtime/infrastructure/database/impl/mongo/models/user.py +++ b/nmdc_runtime/infrastructure/database/impl/mongo/models/user.py @@ -4,7 +4,7 @@ from typing import Optional, List from beanie import Document, Indexed -from pydantic import EmailStr +from pydantic import ConfigDict, EmailStr from nmdc_runtime.api.core.auth import verify_password from nmdc_runtime.domain.users.userSchema import UserAuth, UserUpdate, UserOut @@ -21,9 +21,8 @@ class DocumentMeta: full_name: Optional[str] = None site_admin: Optional[List[str]] = [] disabled: Optional[bool] = False - - class Config: - schema_extra = { + model_config = ConfigDict( + json_schema_extra={ "username": "bob", "email": "test@test.com", "full_name": "test", @@ -31,6 +30,7 @@ class Config: "site_admin": ["test_site"], "created_date": "1/1/2020", } + ) class UserQueries(IUserQueries): diff --git a/nmdc_runtime/minter/domain/model.py b/nmdc_runtime/minter/domain/model.py index c25dec48..28436509 100644 --- a/nmdc_runtime/minter/domain/model.py +++ b/nmdc_runtime/minter/domain/model.py @@ -9,7 +9,7 @@ class Entity(BaseModel): """A domain object whose attributes may change but has a recognizable identity over time.""" - id: str | None + id: str | None = None class ValueObject(BaseModel): @@ -65,7 +65,7 @@ class Identifier(Entity): typecode: Entity shoulder: Entity status: Status - bindings: Optional[dict] + bindings: Optional[dict] = None class Typecode(Entity): diff --git a/nmdc_runtime/util.py b/nmdc_runtime/util.py index 316cccdd..23e94dd1 100644 --- a/nmdc_runtime/util.py +++ b/nmdc_runtime/util.py @@ -17,13 +17,14 @@ from jsonschema.validators import Draft7Validator from nmdc_schema.nmdc_schema_accepting_legacy_ids import Database as NMDCDatabase from nmdc_schema.get_nmdc_view import ViewGetter -from pydantic import conint, BaseModel +from pydantic import Field, BaseModel from pymongo.database import Database as MongoDatabase from pymongo.errors import OperationFailure from toolz import merge, unique from nmdc_runtime.api.core.util import sha256hash_from_file from nmdc_runtime.api.models.object import DrsObjectIn +from typing_extensions import Annotated @lru_cache @@ -354,7 +355,7 @@ class UpdateStatement(BaseModel): class DeleteStatement(BaseModel): q: dict - limit: conint(ge=0, le=1) = 1 + limit: Annotated[int, Field(ge=0, le=1)] = 1 class OverlayDBError(Exception):