Commit 51ca1ff: Switch to Pydantic 2
Pydantic 2 offers more modern methods and better stability than Pydantic 1.

Signed-off-by: kingbri <bdashore3@proton.me>
bdashore3 committed Dec 19, 2023
1 parent f631dd6 commit 51ca1ff
Showing 8 changed files with 18 additions and 15 deletions.
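
At a glance, the migration is a set of mechanical renames at every BaseModel call site. A minimal sketch of the mapping (the User model below is illustrative, not from the repository):

    from pydantic import BaseModel

    class User(BaseModel):  # hypothetical model for illustration
        name: str

    user = User.model_validate({"name": "kingbri"})  # Pydantic 1: User.parse_obj(...)
    as_dict = user.model_dump()                      # Pydantic 1: user.dict()
    as_json = user.model_dump_json()                 # Pydantic 1: user.json()
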
5 changes: 4 additions & 1 deletion OAI/types/model.py
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, ConfigDict
 from time import time
 from typing import List, Optional
 from gen_logging import LogConfig

@@ -45,6 +45,9 @@ class ModelLoadRequest(BaseModel):
     draft: Optional[DraftModelLoadRequest] = None

 class ModelLoadResponse(BaseModel):
+    # Avoids pydantic namespace warning
+    model_config = ConfigDict(protected_namespaces = [])
+
     model_type: str = "model"
     module: int
     modules: int
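
The ConfigDict line added above suppresses Pydantic 2's protected-namespace warning, which fires because field names such as model_type start with the reserved "model_" prefix. A standalone sketch of the behavior (class names are illustrative; the Pydantic docs pass a tuple, protected_namespaces = (), while the commit's list also works at runtime):

    import warnings
    from pydantic import BaseModel, ConfigDict

    with warnings.catch_warnings(record = True) as caught:
        warnings.simplefilter("always")

        class Noisy(BaseModel):
            model_type: str = "model"  # clashes with the "model_" protected namespace

    print(len(caught))  # 1: a UserWarning about "model_type"

    class Quiet(BaseModel):
        model_config = ConfigDict(protected_namespaces = ())

        model_type: str = "model"  # no warning emitted
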
4 changes: 2 additions & 2 deletions auth.py
@@ -30,7 +30,7 @@ def load_auth_keys():
     try:
         with open("api_tokens.yml", "r", encoding = 'utf8') as auth_file:
             auth_keys_dict = yaml.safe_load(auth_file)
-            auth_keys = AuthKeys.parse_obj(auth_keys_dict)
+            auth_keys = AuthKeys.model_validate(auth_keys_dict)
     except Exception as _:
         new_auth_keys = AuthKeys(
             api_key = secrets.token_hex(16),

@@ -39,7 +39,7 @@ def load_auth_keys():
         auth_keys = new_auth_keys

     with open("api_tokens.yml", "w", encoding = "utf8") as auth_file:
-        yaml.safe_dump(auth_keys.dict(), auth_file, default_flow_style=False)
+        yaml.safe_dump(auth_keys.model_dump(), auth_file, default_flow_style=False)

     print(
         f"Your API key is: {auth_keys.api_key}\n"
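
The auth flow round-trips the keys model through YAML, so both directions of the rename appear in one function. A hedged sketch of that pattern, with AuthKeys reduced to a single assumed field:

    import yaml
    from pydantic import BaseModel

    class AuthKeys(BaseModel):  # simplified stand-in for the real model
        api_key: str

    keys = AuthKeys(api_key = "abc123")

    text = yaml.safe_dump(keys.model_dump(), default_flow_style = False)  # was .dict()
    loaded = AuthKeys.model_validate(yaml.safe_load(text))                # was .parse_obj()
    assert loaded == keys
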
2 changes: 1 addition & 1 deletion gen_logging.py
@@ -18,7 +18,7 @@ def update_from_dict(options_dict: Dict[str, bool]):
         if value is None:
             value = False

-    config = LogConfig.parse_obj(options_dict)
+    config = LogConfig.model_validate(options_dict)

 def broadcast_status():
     enabled = []
14 changes: 7 additions & 7 deletions main.py
@@ -117,7 +117,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
     model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
     model_path = model_path / data.name

-    load_data = data.dict()
+    load_data = data.model_dump()

     # TODO: Add API exception if draft directory isn't found
     draft_config = unwrap(model_config.get("draft"), {})

@@ -156,7 +156,7 @@ async def generator():
                     status="finished"
                 )

-                yield get_sse_packet(response.json(ensure_ascii = False))
+                yield get_sse_packet(response.model_dump_json())

                 # Switch to model progress if the draft model is loaded
                 if model_container.draft_config:

@@ -171,7 +171,7 @@ async def generator():
                     status="processing"
                 )

-                yield get_sse_packet(response.json(ensure_ascii=False))
+                yield get_sse_packet(response.model_dump_json())
         except CancelledError:
             print("\nError: Model load cancelled by user. Please make sure to run unload to free up resources.")
         except Exception as e:

@@ -230,7 +230,7 @@ async def load_lora(data: LoraLoadRequest):
     if len(model_container.active_loras) > 0:
         model_container.unload(True)

-    result = model_container.load_loras(lora_dir, **data.dict())
+    result = model_container.load_loras(lora_dir, **data.model_dump())
     return LoraLoadResponse(
         success = unwrap(result.get("success"), []),
         failure = unwrap(result.get("failure"), [])

@@ -281,7 +281,7 @@ async def generator():
                     completion_tokens,
                     model_path.name)

-                yield get_sse_packet(response.json(ensure_ascii=False))
+                yield get_sse_packet(response.model_dump_json())
         except CancelledError:
             print("Error: Completion request cancelled by user.")
         except Exception as e:

@@ -334,15 +334,15 @@ async def generator():
                     model_path.name
                 )

-                yield get_sse_packet(response.json(ensure_ascii=False))
+                yield get_sse_packet(response.model_dump_json())

                 # Yield a finish response on successful generation
                 finish_response = create_chat_completion_stream_chunk(
                     const_id,
                     finish_reason = "stop"
                 )

-                yield get_sse_packet(finish_response.json(ensure_ascii=False))
+                yield get_sse_packet(finish_response.model_dump_json())
         except CancelledError:
             print("Error: Chat completion cancelled by user.")
         except Exception as e:
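
One subtlety in the main.py hunks: the old calls passed ensure_ascii = False so non-ASCII output was not escaped, and the new model_dump_json() has no such flag because Pydantic 2 leaves non-ASCII characters unescaped by default, so dropping the argument preserves behavior. A small illustration (the Message model is hypothetical):

    from pydantic import BaseModel

    class Message(BaseModel):
        text: str

    print(Message(text = "héllo").model_dump_json())
    # {"text":"héllo"}  (no \u00e9 escaping, matching the old json(ensure_ascii = False))
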
2 changes: 1 addition & 1 deletion requirements-amd.txt
@@ -8,7 +8,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.1

 # Pip dependencies
 fastapi
-pydantic < 2,>= 1
+pydantic
 PyYAML
 progress
 uvicorn
2 changes: 1 addition & 1 deletion requirements-cu118.txt
@@ -14,7 +14,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.1

 # Pip dependencies
 fastapi
-pydantic < 2,>= 1
+pydantic
 PyYAML
 progress
 uvicorn
2 changes: 1 addition & 1 deletion requirements.txt
@@ -14,7 +14,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.1

 # Pip dependencies
 fastapi
-pydantic < 2,>= 1
+pydantic
 PyYAML
 progress
 uvicorn
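
All three requirements files relax the same constraint: dropping the pydantic < 2,>= 1 pin lets pip resolve the latest Pydantic 2.x release. If guarding against a future major version were desired, a stricter pin (hypothetical, not part of this commit) would read:

    pydantic >= 2, < 3
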
2 changes: 1 addition & 1 deletion utils.py
@@ -26,7 +26,7 @@ def get_generator_error(message: str):

     # Log and send the exception
     print(f"\n{generator_error.error.trace}")
-    return get_sse_packet(generator_error.json(ensure_ascii = False))
+    return get_sse_packet(generator_error.model_dump_json())

 def get_sse_packet(json_data: str):
     return f"data: {json_data}\n\n"
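
get_sse_packet wraps the serialized model in server-sent-event framing, which pairs directly with model_dump_json(). A short usage sketch (ErrorResponse is a stand-in for the real error model):

    from pydantic import BaseModel

    class ErrorResponse(BaseModel):
        message: str

    def get_sse_packet(json_data: str):
        return f"data: {json_data}\n\n"

    packet = get_sse_packet(ErrorResponse(message = "oops").model_dump_json())
    print(repr(packet))  # 'data: {"message":"oops"}\n\n'
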
