Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance(client): support pydantic v1 and v2 #3030

Merged
merged 3 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/client.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ jobs:
working-directory: ./client
run: |
make install-dev-req
# force using pydantic v2 to do lint checking
python -m pip install -U 'pydantic==2.*'
make install-sw

- name: Black Format Check
Expand Down Expand Up @@ -100,6 +102,9 @@ jobs:
os:
- macos-latest
- ubuntu-latest
pydantic-version:
- "1"
- "2"
runs-on: ${{ matrix.os }}
defaults:
run:
Expand Down Expand Up @@ -134,6 +139,7 @@ jobs:
run: |
make install-dev-req
make install-sw
python -m pip install -U "pydantic==${{matrix.pydantic-version}}.*"

- name: Git Config
run: |
Expand Down
3 changes: 2 additions & 1 deletion client/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ gen-model:
--custom-file-header "# Generated by make gen-model. DO NOT EDIT!" \
--url $(OPEN_API_URL)/v3/api-docs/Api \
--snake-case-field \
--reuse-model \
--base-class starwhale.base.models.base.SwBaseModel \
--output-model-type pydantic_v2.BaseModel \
--custom-template-dir model_gen_templates \
--output starwhale/base/client/models/models.py

47 changes: 47 additions & 0 deletions client/model_gen_templates/pydantic/BaseModel.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
{% for decorator in decorators -%}
{{ decorator }}
{% endfor -%}
class {{ class_name }}({{ base_class }}):{% if comment is defined %} # {{ comment }}{% endif %}
{%- if description %}
"""
{{ description | indent(4) }}
"""
{%- endif %}
{%- if not fields and not description %}
pass
{%- endif %}
{%- if config %}
{%- filter indent(4) %}
{% include 'Config.jinja2' %}
{%- endfilter %}
{%- endif %}
{%- for field in fields -%}
{%- if not field.annotated and field.field %}
{%- if field.type_hint[:7] == 'constr(' %}
{{ field.name }}: str = {{ field.field }}
{%- else %}
{%- if field.type_hint[:16] == 'Optional[constr(' %}
{{ field.name }}: Optional[str] = {{ field.field }}
{%- else %}
{{ field.name }}: {{ field.type_hint }} = {{ field.field | replace(", unique_items=True", "") }}
{%- endif %}
{%- endif %}
{%- else %}
{%- if field.annotated %}
{{ field.name }}: {{ field.annotated }}
{%- else %}
{{ field.name }}: {{ field.type_hint }}
{%- endif %}
{%- if not (field.required or (field.represented_default == 'None' and field.strip_default_none))
%} = {{ field.represented_default }}
{%- endif -%}
{%- endif %}
{%- if field.docstring %}
"""
{{ field.docstring | indent(4) }}
"""
{%- endif %}
{%- for method in methods -%}
{{ method }}
{%- endfor -%}
{%- endfor -%}
3 changes: 1 addition & 2 deletions client/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@
"trio>=0.22.0",
"httpx>=0.22.0",
"urllib3<1.27",
"pydantic>=2.0.0",
"pydantic-settings",
"pydantic",
"sortedcontainers",
"importlib_resources",
# workaround: email-validator 2.1.0 has a syntax error in python 3.7, but the email-validator is necessary for fastapi.
Expand Down
17 changes: 2 additions & 15 deletions client/starwhale/api/_impl/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import requests
import tenacity
import jsonlines
from pydantic import validator
from sortedcontainers import SortedDict, SortedList, SortedValuesView
from typing_extensions import Protocol

Expand Down Expand Up @@ -82,6 +81,7 @@ def encode_schema(type: "SwType", **kwargs: Any) -> ColumnSchemaDesc:
if isinstance(type, SwScalarType):
return ColumnSchemaDesc(type=str(type), **kwargs)
if isinstance(type, SwTupleType):
# https://github.com/pydantic/pydantic/issues/6022#issuecomment-1668916345
ret = ColumnSchemaDesc(
type="TUPLE",
element_type=SwType.encode_schema(type.main_type),
Expand Down Expand Up @@ -1557,19 +1557,6 @@ class TombstoneDesc(SwBaseModel):
# This works only if the key is a string
key_prefix: Optional[str] = None

@validator("end")
def end_must_be_greater_than_start(
cls, v: KeyType, values: Dict[str, Any]
) -> KeyType:
if v is None:
return None
if values["start"] is not None:
if not isinstance(v, type(values["start"])):
raise ValueError("end has different type with start")
if v <= values["start"]:
raise ValueError("end must be greater than start")
return v # type: ignore

def __gt__(self, other: object) -> Any:
if not isinstance(other, TombstoneDesc):
raise NotImplementedError
Expand Down Expand Up @@ -2445,7 +2432,7 @@ def _dump_manifest(self, manifest: LocalDataStoreManifest) -> None:
manifest_file = Path(self.root_path) / datastore_manifest_file_name
with filelock.FileLock(str(Path(self.root_path) / ".lock")):
with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp:
tmp.write(manifest.model_dump_json(indent=2))
tmp.write(json.dumps(manifest.to_dict()))
tmp.flush()
shutil.move(tmp.name, manifest_file)

Expand Down
4 changes: 1 addition & 3 deletions client/starwhale/api/_impl/service/types/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
from typing import Any, Set, Dict, List, Callable, Optional

from pydantic import BaseModel
from pydantic.dataclasses import dataclass

from starwhale.base.client.models.models import ComponentSpecValueType

from .types import ServiceType, ComponentSpec


@dataclass
class Message:
class Message(BaseModel):
content: str
role: str

Expand Down
24 changes: 13 additions & 11 deletions client/starwhale/base/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import typing

import requests
from pydantic import BaseModel
from requests import exceptions
from tenacity import (
retry,
Expand All @@ -14,16 +13,13 @@
retry_if_exception_type,
wait_random_exponential,
)
from pydantic.tools import parse_obj_as
from fastapi.encoders import jsonable_encoder
from pydantic.type_adapter import TypeAdapter

from starwhale.utils import console
from starwhale.base.models.base import ListFilter
from starwhale.base.models.base import RespType, ListFilter, SwBaseModel, obj_to_model
from starwhale.base.client.models.base import ResponseCode

T = typing.TypeVar("T")
RespType = typing.TypeVar("RespType", bound=BaseModel)
CLIENT_DEFAULT_RETRY_ATTEMPTS = 10


Expand Down Expand Up @@ -54,10 +50,10 @@ def __init__(self, type_: typing.Type[RespType], data: typing.Any) -> None:
self._type = type_
self._data = data
self._raise_on_error = False
self._base = TypeAdapter(ResponseCode).validate_python(data)
self._base = obj_to_model(data, ResponseCode)
self._response: T | None = None
if self.is_success():
self._response = parse_obj_as(self._type, self._data) # type: ignore
self._response = obj_to_model(self._data, self._type) # type: ignore
else:
console.debug(f"request failed, response msg: {self._base.message}")

Expand Down Expand Up @@ -105,12 +101,12 @@ def http_request(
self,
method: str,
uri: str,
json: dict | BaseModel | None = None,
json: dict | SwBaseModel | None = None,
params: dict | None = None,
data: typing.Any = None,
) -> typing.Any:
_json: typing.Any = json
if isinstance(json, BaseModel):
if isinstance(json, SwBaseModel):
# convert to dict with proper alias
_json = jsonable_encoder(json.dict(by_alias=True, exclude_none=True))
resp = self.session.request(
Expand Down Expand Up @@ -140,12 +136,18 @@ def http_get(self, uri: str, params: dict | None = None) -> typing.Any:
return self.http_request("GET", uri, params=params)

def http_post(
self, uri: str, json: dict | BaseModel | None = None, data: typing.Any = None
self,
uri: str,
json: dict | SwBaseModel | None = None,
data: typing.Any = None,
) -> typing.Any:
return self.http_request("POST", uri, json=json, data=data)

def http_put(
self, uri: str, json: dict | BaseModel | None = None, params: dict | None = None
self,
uri: str,
json: dict | SwBaseModel | None = None,
params: dict | None = None,
) -> typing.Any:
return self.http_request("PUT", uri, json=json, params=params)

Expand Down
4 changes: 2 additions & 2 deletions client/starwhale/base/client/models/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import typing

from pydantic import BaseModel
from starwhale.base.models.base import SwBaseModel

Model = typing.TypeVar("Model")


class ResponseCode(BaseModel):
class ResponseCode(SwBaseModel):
code: str
message: str
Loading
Loading