Skip to content

Commit

Permalink
Remove pydantic v1 compat helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
jl-wynen committed Nov 23, 2023
1 parent fdba3ff commit d5d5c92
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 211 deletions.
51 changes: 4 additions & 47 deletions src/scitacean/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@
from dateutil.parser import parse as parse_datetime

from ._internal.orcid import is_valid_orcid
from ._internal.pydantic_compat import is_pydantic_v1
from .filesystem import RemotePath
from .logging import get_logger
from .pid import PID
from .thumbnail import Thumbnail

try:
# Python 3.11+
Expand Down Expand Up @@ -57,20 +53,9 @@ class DatasetType(str, Enum): # type: ignore[no-redef]
class BaseModel(pydantic.BaseModel):
"""Base class for Pydantic models for communication with SciCat."""

if is_pydantic_v1():

class Config:
extra = pydantic.Extra.forbid
json_encoders = { # noqa: RUF012
PID: lambda v: str(v),
RemotePath: lambda v: v.posix,
Thumbnail: lambda v: v.serialize(),
}

else:
model_config = pydantic.ConfigDict(
extra="forbid",
)
model_config = pydantic.ConfigDict(
extra="forbid",
)

_user_mask: ClassVar[Tuple[str, ...]]
_masked_fields: ClassVar[Optional[Tuple[str, ...]]] = None
Expand Down Expand Up @@ -107,7 +92,7 @@ def get_name(name: str, field: Any) -> Any:
return field.alias if field.alias is not None else name

field_names = {
get_name(name, field) for name, field in instance.get_model_fields().items()
get_name(name, field) for name, field in instance.model_fields.items()
}
default_mask = tuple(key for key in _IGNORED_KWARGS if key not in field_names)
cls._masked_fields = cls._user_mask + default_mask
Expand Down Expand Up @@ -137,34 +122,6 @@ def download_model_type(cls) -> Optional[Type[BaseModel]]:
"""
return None

if is_pydantic_v1():

@classmethod
def get_model_fields(cls) -> Dict[str, Any]:
return cls.__fields__ # type: ignore[return-value]

def model_dump(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
return self.dict(*args, **kwargs)

def model_dump_json(self, *args: Any, **kwargs: Any) -> str:
return self.json(*args, **kwargs)

@classmethod
def model_construct(
cls: Type[ModelType], *args: Any, **kwargs: Any
) -> ModelType:
return cls.construct(*args, **kwargs)

@classmethod
def model_rebuild(cls, *args: Any, **kwargs: Any) -> Optional[bool]:
return cls.update_forward_refs(*args, **kwargs)

else:

@classmethod
def get_model_fields(cls) -> Dict[str, pydantic.fields.FieldInfo]:
return cls.model_fields


@dataclasses.dataclass
class BaseUserModel:
Expand Down
2 changes: 1 addition & 1 deletion src/scitacean/_html_repr/_attachment_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,4 @@ def _strip_leading_underscore(s: str) -> str:


def _is_read_only(field_name: str) -> bool:
return field_name not in UploadAttachment.get_model_fields()
return field_name not in UploadAttachment.model_fields
34 changes: 0 additions & 34 deletions src/scitacean/_internal/pydantic_compat.py

This file was deleted.

49 changes: 17 additions & 32 deletions src/scitacean/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,10 @@
import re
from datetime import datetime, timezone
from pathlib import Path, PurePath
from typing import Any, Callable, Generator, Optional, TypeVar, Union
from typing import Any, Optional, TypeVar, Union

from ._internal.pydantic_compat import is_pydantic_v1

if not is_pydantic_v1():
from typing import Type

from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

def _get_remote_path_core_schema(
cls: Type[RemotePath], _source_type: Any, _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_after_validator_function(
cls,
core_schema.union_schema(
[core_schema.is_instance_schema(RemotePath), core_schema.str_schema()]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda p: p.posix if isinstance(p, RemotePath) else str(p),
info_arg=False,
return_schema=core_schema.str_schema(),
),
)
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema


class RemotePath:
Expand Down Expand Up @@ -176,16 +156,21 @@ def validate(cls, value: Union[str, RemotePath]) -> RemotePath:
"""Pydantic validator for RemotePath fields."""
return RemotePath(value)

if is_pydantic_v1():

@classmethod
def __get_validators__(
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_after_validator_function(
cls,
) -> Generator[Callable[[Union[str, RemotePath]], RemotePath], None, None]:
yield cls.validate

else:
__get_pydantic_core_schema__ = classmethod(_get_remote_path_core_schema)
core_schema.union_schema(
[core_schema.is_instance_schema(RemotePath), core_schema.str_schema()]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda p: p.posix if isinstance(p, RemotePath) else str(p),
info_arg=False,
return_schema=core_schema.str_schema(),
),
)


def _posix(path: Union[str, RemotePath]) -> str:
Expand Down
39 changes: 20 additions & 19 deletions src/scitacean/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@
validate_orcids,
)
from ._internal.dataclass_wrapper import dataclass_optional_args
from ._internal.pydantic_compat import field_validator
from .filesystem import RemotePath
from .pid import PID
from .thumbnail import Thumbnail
Expand Down Expand Up @@ -155,19 +154,21 @@ class DownloadDataset(
updatedBy: Optional[str] = None
validationStatus: Optional[str] = None

@field_validator("creationTime", "createdAt", "endTime", "updatedAt", mode="before")
@pydantic.field_validator(
"creationTime", "createdAt", "endTime", "updatedAt", mode="before"
)
def _validate_datetime(cls, value: Any) -> Any:
return validate_datetime(value)

@field_validator("history", mode="before")
@pydantic.field_validator("history", mode="before")
def _validate_drop(cls, value: Any) -> Any:
return validate_drop(value)

@field_validator("contactEmail", "ownerEmail", mode="before")
@pydantic.field_validator("contactEmail", "ownerEmail", mode="before")
def _validate_emails(cls, value: Any) -> Any:
return validate_emails(value)

@field_validator("orcidOfOwner", mode="before")
@pydantic.field_validator("orcidOfOwner", mode="before")
def _validate_orcids(cls, value: Any) -> Any:
return validate_orcids(value)

Expand Down Expand Up @@ -207,15 +208,15 @@ class UploadDerivedDataset(BaseModel):
techniques: Optional[List[UploadTechnique]] = None
validationStatus: Optional[str] = None

@field_validator("creationTime", mode="before")
@pydantic.field_validator("creationTime", mode="before")
def _validate_datetime(cls, value: Any) -> Any:
return validate_datetime(value)

@field_validator("contactEmail", "ownerEmail", mode="before")
@pydantic.field_validator("contactEmail", "ownerEmail", mode="before")
def _validate_emails(cls, value: Any) -> Any:
return validate_emails(value)

@field_validator("orcidOfOwner", mode="before")
@pydantic.field_validator("orcidOfOwner", mode="before")
def _validate_orcids(cls, value: Any) -> Any:
return validate_orcids(value)

Expand Down Expand Up @@ -257,15 +258,15 @@ class UploadRawDataset(BaseModel):
techniques: Optional[List[UploadTechnique]] = None
validationStatus: Optional[str] = None

@field_validator("creationTime", "endTime", mode="before")
@pydantic.field_validator("creationTime", "endTime", mode="before")
def _validate_datetime(cls, value: Any) -> Any:
return validate_datetime(value)

@field_validator("contactEmail", "ownerEmail", mode="before")
@pydantic.field_validator("contactEmail", "ownerEmail", mode="before")
def _validate_emails(cls, value: Any) -> Any:
return validate_emails(value)

@field_validator("orcidOfOwner", mode="before")
@pydantic.field_validator("orcidOfOwner", mode="before")
def _validate_orcids(cls, value: Any) -> Any:
return validate_orcids(value)

Expand All @@ -285,7 +286,7 @@ class DownloadAttachment(BaseModel):
updatedAt: Optional[datetime] = None
updatedBy: Optional[str] = None

@field_validator("createdAt", "updatedAt", mode="before")
@pydantic.field_validator("createdAt", "updatedAt", mode="before")
def _validate_datetime(cls, value: Any) -> Any:
return validate_datetime(value)

Expand Down Expand Up @@ -332,7 +333,7 @@ class DownloadOrigDatablock(BaseModel):
updatedAt: Optional[datetime] = None
updatedBy: Optional[str] = None

@field_validator("createdAt", "updatedAt", mode="before")
@pydantic.field_validator("createdAt", "updatedAt", mode="before")
def _validate_datetime(cls, value: Any) -> Any:
return validate_datetime(value)

Expand Down Expand Up @@ -372,7 +373,7 @@ class DownloadDatablock(BaseModel):
updatedAt: Optional[datetime] = None
updatedBy: Optional[str] = None

@field_validator("createdAt", "updatedAt", mode="before")
@pydantic.field_validator("createdAt", "updatedAt", mode="before")
def _validate_datetime(cls, value: Any) -> Any:
return validate_datetime(value)

Expand Down Expand Up @@ -410,7 +411,7 @@ class DownloadLifecycle(BaseModel):
retrieveReturnMessage: Optional[Dict[str, Any]] = None
retrieveStatusMessage: Optional[str] = None

@field_validator(
@pydantic.field_validator(
"archiveRetentionTime",
"dateOfDiskPurging",
"dateOfPublishing",
Expand Down Expand Up @@ -482,7 +483,7 @@ class DownloadHistory(BaseModel):
updatedAt: Optional[datetime] = None
updatedBy: Optional[datetime] = None

@field_validator("updatedAt", mode="before")
@pydantic.field_validator("updatedAt", mode="before")
def _validate_datetime(cls, value: Any) -> Any:
return validate_datetime(value)

Expand All @@ -500,7 +501,7 @@ class DownloadDataFile(BaseModel):
perm: Optional[str] = None
uid: Optional[str] = None

@field_validator("time", mode="before")
@pydantic.field_validator("time", mode="before")
def _validate_datetime(cls, value: Any) -> Any:
return validate_datetime(value)

Expand All @@ -518,7 +519,7 @@ class UploadDataFile(BaseModel):
perm: Optional[str] = None
uid: Optional[str] = None

@field_validator("time", mode="before")
@pydantic.field_validator("time", mode="before")
def _validate_datetime(cls, value: Any) -> Any:
return validate_datetime(value)

Expand Down Expand Up @@ -552,7 +553,7 @@ class DownloadSample(BaseModel):
updatedAt: Optional[datetime] = None
updatedBy: Optional[str] = None

@field_validator("createdAt", "updatedAt", mode="before")
@pydantic.field_validator("createdAt", "updatedAt", mode="before")
def _validate_datetime(cls, value: Any) -> Any:
return validate_datetime(value)

Expand Down
51 changes: 17 additions & 34 deletions src/scitacean/pid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,10 @@
from __future__ import annotations

import uuid
from typing import Callable, Generator, Optional, Union
from typing import Any, Optional, Union

import pydantic

from ._internal.pydantic_compat import is_pydantic_v1

if not is_pydantic_v1():
from typing import Any, Type

from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

def _get_pid_core_schema(
cls: Type[PID], _source_type: Any, _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_after_validator_function(
cls.parse,
core_schema.union_schema(
[core_schema.is_instance_schema(PID), core_schema.str_schema()]
),
serialization=core_schema.plain_serializer_function_ser_schema(
cls.__str__, info_arg=False, return_schema=core_schema.str_schema()
),
)
from pydantic import GetCoreSchemaHandler, ValidationError
from pydantic_core import core_schema


class PID:
Expand Down Expand Up @@ -148,15 +128,18 @@ def validate(cls, value: Union[str, PID]) -> PID:
return PID.parse(value)
if isinstance(value, PID):
return value
raise pydantic.ValidationError("expected a PID or str")
raise ValidationError("expected a PID or str")

if is_pydantic_v1():

@classmethod
def __get_validators__(
cls,
) -> Generator[Callable[[Union[str, PID]], PID], None, None]:
yield cls.validate

else:
__get_pydantic_core_schema__ = classmethod(_get_pid_core_schema)
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_after_validator_function(
cls.parse,
core_schema.union_schema(
[core_schema.is_instance_schema(PID), core_schema.str_schema()]
),
serialization=core_schema.plain_serializer_function_ser_schema(
cls.__str__, info_arg=False, return_schema=core_schema.str_schema()
),
)
Loading

0 comments on commit d5d5c92

Please sign in to comment.