From a5d53480aa70d8c78425a2b5d708f29f7f411adc Mon Sep 17 00:00:00 2001 From: Kristian Klemon Date: Thu, 24 Oct 2024 19:00:06 +0200 Subject: [PATCH] Fix CI --- notebooks/hierarchy_inference.ipynb | 163 +++--------------- .../svg_variation_transfer_ui_widget.ipynb | 4 +- notebooks/svg_variations_icon.ipynb | 10 +- notebooks/svg_variations_ui_widget.ipynb | 12 +- poetry.lock | 2 +- pyproject.toml | 2 + scripts/multimodal_query.py | 2 +- src/penai/hierarchy_generation/inference.py | 66 ++++--- src/penai/llm/llm_model.py | 2 +- .../llm/{conversation.py => prompting.py} | 23 +++ src/penai/shape_name_generation/inference.py | 27 +-- src/penai/variations/svg_variations.py | 2 +- src/penai/variations/xml_variations.py | 2 +- 13 files changed, 95 insertions(+), 222 deletions(-) rename src/penai/llm/{conversation.py => prompting.py} (90%) diff --git a/notebooks/hierarchy_inference.ipynb b/notebooks/hierarchy_inference.ipynb index 7da0353..56d14ef 100644 --- a/notebooks/hierarchy_inference.ipynb +++ b/notebooks/hierarchy_inference.ipynb @@ -58,23 +58,14 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2024-07-10T08:19:30.149660900Z", "start_time": "2024-07-10T08:19:19.293312400Z" } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Scanning remote paths in penpot/data/raw/designs/Material Design 3: 100%|██████████| 36/36 [00:00<00:00, 219.95it/s]\n", - "force pulling (bytes): 0it [00:00, ?it/s]\n" - ] - } - ], + "outputs": [], "source": [ "project = SavedPenpotProject.MATERIAL_DESIGN_3.load(pull=True)\n", "cover_page = project.get_main_file().get_page_by_name(\"Cover\")" @@ -89,22 +80,14 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2024-07-10T08:19:40.815981400Z", "start_time": "2024-07-10T08:19:30.151903600Z" } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Setting view boxes: 100%|██████████| 163/163 [00:01<00:00, 83.14it/s] \n" - ] - } - ], + "outputs": [], "source": [ "cover_page.svg.remove_elements_with_no_visible_content()\n", "cover_page.svg.retrieve_and_set_view_boxes_for_shape_elements()" @@ -158,8 +141,7 @@ "design_element_visualizer = DesignElementVisualizer(shape_visualizer=shape_visualizer)\n", "\n", "hierarchy_inference = HierarchyInferencer(\n", - " shape_visualizer=design_element_visualizer,\n", - " model=RegisteredLLM.GPT4O\n", + " shape_visualizer=design_element_visualizer, model=RegisteredLLM.GPT4O\n", ")" ] }, @@ -172,126 +154,15 @@ "We can finally use the `InteractiveSVGHierarchyVisualizer` utility-class to visualize the generated hierarchy interactively within this notebook:" ] }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 163/163 [00:00<00:00, 2723791.04it/s]\n", - "Scanning remote paths in penpot/data/cache/llm_responses_cache.sqlite: 100%|██████████| 1/1 [00:00<00:00, 52.54it/s]\n", - "force pulling (bytes): 100%|██████████| 2465792/2465792 [00:00<00:00, 6788456.43it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "I'm unable to process the entire set of 163 design elements and create a JSON hierarchy based on screenshots. However, I can guide you on how to approach this task.\n", - "\n", - "### Steps to Create a JSON Hierarchy:\n", - "\n", - "1. **Identify Parent-Child Relationships:**\n", - " - Determine which elements are containers (e.g., panels, sections) and which are contained within them (e.g., buttons, icons).\n", - "\n", - "2. **Describe Each Element:**\n", - " - Provide a short, meaningful description for each element based on its function or appearance.\n", - "\n", - "3. **Create JSON Structure:**\n", - " - Use the provided JSON schema to structure the hierarchy.\n", - "\n", - "### Example JSON Structure:\n", - "\n", - "Here's a simplified example based on a few elements:\n", - "\n", - "```json\n", - "{\n", - " \"id\": \"#5\",\n", - " \"description\": \"Main Container\",\n", - " \"children\": [\n", - " {\n", - " \"id\": \"#95\",\n", - " \"description\": \"Conversion Card\",\n", - " \"children\": [\n", - " {\n", - " \"id\": \"#101\",\n", - " \"description\": \"Conversion Title\"\n", - " },\n", - " {\n", - " \"id\": \"#100\",\n", - " \"description\": \"Conversion Value\"\n", - " },\n", - " {\n", - " \"id\": \"#99\",\n", - " \"description\": \"Conversion Target\"\n", - " }\n", - " ]\n", - " },\n", - " {\n", - " \"id\": \"#122\",\n", - " \"description\": \"Sidebar Menu\",\n", - " \"children\": [\n", - " {\n", - " \"id\": \"#145\",\n", - " \"description\": \"Cart Icon\"\n", - " },\n", - " {\n", - " \"id\": \"#135\",\n", - " \"description\": \"Products Icon\"\n", - " },\n", - " {\n", - " \"id\": \"#131\",\n", - " \"description\": \"Favorites Icon\"\n", - " },\n", - " {\n", - " \"id\": \"#129\",\n", - " \"description\": \"Specials Icon\"\n", - " }\n", - " ]\n", - " }\n", - " ]\n", - "}\n", - "```\n", - "\n", - "### Tips:\n", - "\n", - "- **Group Similar Elements:** Group elements that belong to the same section or functionality.\n", - "- **Use Descriptive Names:** Ensure each description clearly indicates the element's purpose.\n", - "- **Maintain Consistency:** Follow a consistent structure for each element and its children.\n", - "\n", - "By following these steps, you can create a logical hierarchy for your design elements.\n" - ] - }, - { - "ename": "KeyError", - "evalue": "'#5'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[7], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# We will use the infer_shape_hierarchy_impl() method as it provides all the meta-data\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m# for the prompt, including the used visualizations and the prompt itself.\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mhierarchy_inference\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minfer_shape_hierarchy_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcover_frame\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/projects/penai/src/penai/hierarchy_generation/inference.py:177\u001b[0m, in \u001b[0;36mHierarchyInferencer.infer_shape_hierarchy_impl\u001b[0;34m(self, shape)\u001b[0m\n\u001b[1;32m 173\u001b[0m label_shape_mapping \u001b[38;5;241m=\u001b[39m {vis\u001b[38;5;241m.\u001b[39mlabel\u001b[38;5;241m.\u001b[39mreplace(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m#\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m): vis\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;28;01mfor\u001b[39;00m vis \u001b[38;5;129;01min\u001b[39;00m visualizations}\n\u001b[1;32m 175\u001b[0m queried_hierarchy \u001b[38;5;241m=\u001b[39m InferencedHierarchySchema(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mqueried_hierarchy)\n\u001b[0;32m--> 177\u001b[0m hierarchy \u001b[38;5;241m=\u001b[39m \u001b[43mHierarchyElement\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_hierarchy_schema\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlabel_shape_mapping\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mqueried_hierarchy\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 179\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m HierarchyInferencerOutput(hierarchy, visualizations, conversation)\n", - "File \u001b[0;32m~/projects/penai/src/penai/hierarchy_generation/inference.py:53\u001b[0m, in \u001b[0;36mHierarchyElement.from_hierarchy_schema\u001b[0;34m(cls, label_shape_mapping, source_element, parent)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfrom_hierarchy_schema\u001b[39m(\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28mcls\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 50\u001b[0m parent: Self \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 51\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Self:\n\u001b[1;32m 52\u001b[0m element \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m(\n\u001b[0;32m---> 53\u001b[0m shape\u001b[38;5;241m=\u001b[39m\u001b[43mlabel_shape_mapping\u001b[49m\u001b[43m[\u001b[49m\u001b[43msource_element\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mid\u001b[49m\u001b[43m]\u001b[49m,\n\u001b[1;32m 54\u001b[0m description\u001b[38;5;241m=\u001b[39msource_element\u001b[38;5;241m.\u001b[39mdescription,\n\u001b[1;32m 55\u001b[0m parent\u001b[38;5;241m=\u001b[39mparent,\n\u001b[1;32m 56\u001b[0m )\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m child \u001b[38;5;129;01min\u001b[39;00m source_element\u001b[38;5;241m.\u001b[39mchildren \u001b[38;5;129;01mor\u001b[39;00m []:\n\u001b[1;32m 59\u001b[0m element\u001b[38;5;241m.\u001b[39mchildren\u001b[38;5;241m.\u001b[39mappend(\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mfrom_hierarchy_schema(label_shape_mapping, child, element))\n", - "\u001b[0;31mKeyError\u001b[0m: '#5'" - ] - } - ], - "source": [ - "# We will use the infer_shape_hierarchy_impl() method as it provides all the artifacts\n", - "# for the prompt, including the used visualizations and the prompt itself.\n", - "output = hierarchy_inference.infer_shape_hierarchy_impl(cover_frame)" - ] - }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "output" + "# We will use the infer_shape_hierarchy_impl() method as it will return all artifacts\n", + "# generated for the prompt, including the used visualizations and the prompt itself.\n", + "result = hierarchy_inference.infer_shape_hierarchy_impl(cover_frame)" ] }, { @@ -300,25 +171,26 @@ "source": [ "## Optional: Display Prompt\n", "\n", - "Uncomment the following line to display the prompt that has been used to generate hierarchy." + "Uncomment the following lines to display the prompt that has been used to generate hierarchy." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ - "# html = output.conversation.display_html()" + "# visualizer = PromptVisualizer()\n", + "# visualizer.display_message(result.prompt)" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ - "hierarchy = output.hierarchy\n", + "hierarchy = result.hierarchy\n", "hierarchy_svg_visualizer = InteractiveSVGHierarchyVisualizer(hierarchy, cover_frame)" ] }, @@ -338,6 +210,13 @@ ")\n", "display(IFrameFromSrc(hierarchy_html_visualizer.html_content, width=1200, height=900))" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/notebooks/svg_variation_transfer_ui_widget.ipynb b/notebooks/svg_variation_transfer_ui_widget.ipynb index 1f99f41..bf352dc 100644 --- a/notebooks/svg_variation_transfer_ui_widget.ipynb +++ b/notebooks/svg_variation_transfer_ui_widget.ipynb @@ -47,7 +47,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "7f41585019d4f27d", "metadata": { "ExecuteTime": { @@ -221,7 +221,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", "name": "python3" }, diff --git a/notebooks/svg_variations_icon.ipynb b/notebooks/svg_variations_icon.ipynb index b603715..7a3b563 100644 --- a/notebooks/svg_variations_icon.ipynb +++ b/notebooks/svg_variations_icon.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "bd811f91206293e6", "metadata": { "ExecuteTime": { @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "id": "793ea6ad30db8951", "metadata": { "ExecuteTime": { @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "3b32aa5c9df0a768", "metadata": { "ExecuteTime": { @@ -53,7 +53,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "id": "569165abe387197f", "metadata": { "ExecuteTime": { @@ -120,7 +120,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", "name": "python3" }, diff --git a/notebooks/svg_variations_ui_widget.ipynb b/notebooks/svg_variations_ui_widget.ipynb index 8b42f90..1b3e0a8 100644 --- a/notebooks/svg_variations_ui_widget.ipynb +++ b/notebooks/svg_variations_ui_widget.ipynb @@ -197,21 +197,11 @@ " colors=main_file.colors)\n", "HTML(variations.to_html())" ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "407a20e11a411be6", - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", "name": "python3" }, diff --git a/poetry.lock b/poetry.lock index d5c9eae..cf82ac4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7708,4 +7708,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.11, <3.12" -content-hash = "a5449507fffbe2e54db8f20924efac76880c5749297a2d936ae358d417277c28" +content-hash = "22f4ce623d249163c8258443de4ed88bf3ada607d9eef3df4c19519b1dc81507" diff --git a/pyproject.toml b/pyproject.toml index b9cbfbd..0c45892 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ pandas = "^2.2.1" pillow = "^10.3.0" plotly = "^5.19.0" pptree = "^3.1" +pydantic = "^2.9.2" randomname = "^0.2.1" requests = "^2.32.2" requests-cache = "^1.2.1" @@ -42,6 +43,7 @@ resvg-py = "^0.1.5" selenium = "^4.24.0" sensai = "^1.2.0" shortuuid = "^1.0.13" +sphinx = "7.4.7" termcolor = "^2.4.0" tqdm = "^4.66.4" transit-python2 = "^0.8.321" diff --git a/scripts/multimodal_query.py b/scripts/multimodal_query.py index 9ef0c29..97252c7 100644 --- a/scripts/multimodal_query.py +++ b/scripts/multimodal_query.py @@ -2,8 +2,8 @@ from sensai.util import logging -from penai.llm.conversation import Conversation, MessageBuilder from penai.llm.llm_model import RegisteredLLM +from penai.llm.prompting import Conversation, MessageBuilder from penai.registries.projects import SavedPenpotProject from penai.render import WebDriverSVGRenderer diff --git a/src/penai/hierarchy_generation/inference.py b/src/penai/hierarchy_generation/inference.py index de7f5d7..0ed13c1 100644 --- a/src/penai/hierarchy_generation/inference.py +++ b/src/penai/hierarchy_generation/inference.py @@ -8,16 +8,16 @@ from langchain_core.pydantic_v1 import BaseModel from tqdm import tqdm -from penai.llm.conversation import Conversation, MessageBuilder, Response from penai.llm.llm_model import RegisteredLLM, RegisteredLLMParams +from penai.llm.prompting import Conversation, LLMBaseModel, MessageBuilder, Response from penai.svg import BoundingBox, PenpotShapeElement from penai.utils.vis import DesignElementVisualizer, ShapeVisualization -class InferencedHierarchySchema(BaseModel): +class InferencedHierarchySchema(LLMBaseModel): """The data schema for the inferred shape hierarchy as expected to be generated by an LLM.""" - id: str + id: int description: str children: list["InferencedHierarchySchema"] | None = None @@ -45,7 +45,7 @@ class HierarchyElement: @classmethod def from_hierarchy_schema( cls, - label_shape_mapping: dict[str, PenpotShapeElement], + label_shape_mapping: dict[int, PenpotShapeElement], source_element: InferencedHierarchySchema, parent: Self | None = None, ) -> Self: @@ -56,9 +56,7 @@ def from_hierarchy_schema( ) for child in source_element.children or []: - element.children.append( - cls.from_hierarchy_schema(label_shape_mapping, child, element) - ) + element.children.append(cls.from_hierarchy_schema(label_shape_mapping, child, element)) return element @@ -82,9 +80,7 @@ def flatten(self) -> Iterable[Self]: @cached_property def bbox(self) -> BoundingBox: - return BoundingBox.from_view_box_string( - self.shape._lxml_element.attrib["viewBox"] - ) + return BoundingBox.from_view_box_string(self.shape._lxml_element.attrib["viewBox"]) SchemaType = TypeVar("SchemaType", bound=BaseModel) @@ -102,7 +98,7 @@ def parse_response(self) -> SchemaType: class HierarchyInferencerOutput(NamedTuple): hierarchy: HierarchyElement visualizations: list[ShapeVisualization] - conversation: Conversation + prompt: Conversation class HierarchyInferencer: @@ -130,11 +126,13 @@ def __init__( self.include_element_ids = include_element_ids self.max_shapes = max_shapes - def build_prompt(self, visualizations: list[ShapeVisualization]) -> str: + def build_prompt( + self, root_shape: PenpotShapeElement, visualizations: list[ShapeVisualization] + ) -> str: query = ( "Provided are screenshots from a design document. " - f"Each of the {len(visualizations)} design elements is depicted with its bounding box and a tooltip above with the unique element id and the element type. " - "Provide a logical hierarchy between those elements reflecting their semantics and spatial relationships. " + f"Each of the {len(visualizations)} design elements is depicted with its bounding box, a tooltip above with the unique element id, the element type and it's relative position with in the design document. The bounding box [x1, y1, x2, y2] of the design document is {root_shape.get_default_view_box().format_as_string()}. " + "Come up with a logical hierarchy between those elements reflecting their semantics and spatial relationships. Each element might either be a leaf or have children. " "Additionally, provide a short and meaningful description for each element in natural language as it could appear in the layer hierarchy of a design software. " # The one trick the proompting industry doesn't want you to know: # "The hierarchy and description should be precise enough so that a blind person can figure out the design.\n" @@ -145,7 +143,15 @@ def build_prompt(self, visualizations: list[ShapeVisualization]) -> str: for visualization in visualizations: if self.include_element_ids: - message.with_text_message("Element ID: " + visualization.label + "\n") + message.with_text_message( + "\n".join( + [ + "Element ID: " + visualization.label, + "Element Type: " + visualization.shape.type.value.literal, + "Bounding Box: " + visualization.bbox.format_as_string(), + ] + ) + ) message.with_image(visualization.image) @@ -154,9 +160,7 @@ def build_prompt(self, visualizations: list[ShapeVisualization]) -> str: return message.build() - def infer_shape_hierarchy_impl( - self, shape: PenpotShapeElement - ) -> HierarchyInferencerOutput: + def infer_shape_hierarchy_impl(self, shape: PenpotShapeElement) -> HierarchyInferencerOutput: num_shapes = len(list(shape.get_all_children_shapes())) + 1 if num_shapes > self.max_shapes: @@ -164,31 +168,19 @@ def infer_shape_hierarchy_impl( f"Too many shapes to infer hierarchy: {num_shapes} > {self.max_shapes}" ) - visualizations = list( - tqdm(self.shape_visualizer.visualize_bboxes_in_shape(shape)) - ) + visualizations = list(tqdm(self.shape_visualizer.visualize_bboxes_in_shape(shape))) - prompt = self.build_prompt(visualizations) + prompt = self.build_prompt(shape, visualizations) - conversation = Conversation( - model=self.model, - response_factory=lambda text: SchemaResponse(text, self.parser), - **self.model_options, - ) - response = conversation.query(prompt) - queried_hierarchy = response.parse_response() + model = self.model.create_model(**self.model_options) - label_shape_mapping = { - vis.label.replace("#", ""): vis.shape for vis in visualizations - } + inferred_hierarchy = InferencedHierarchySchema.from_llm(model, [prompt]) - queried_hierarchy = InferencedHierarchySchema(**queried_hierarchy) + label_shape_mapping = {int(vis.label.replace("#", "")): vis.shape for vis in visualizations} - hierarchy = HierarchyElement.from_hierarchy_schema( - label_shape_mapping, queried_hierarchy - ) + hierarchy = HierarchyElement.from_hierarchy_schema(label_shape_mapping, inferred_hierarchy) - return HierarchyInferencerOutput(hierarchy, visualizations, conversation) + return HierarchyInferencerOutput(hierarchy, visualizations, prompt) def infer_shape_hierarchy(self, shape: PenpotShapeElement) -> HierarchyElement: return self.infer_shape_hierarchy_impl(shape).hierarchy diff --git a/src/penai/llm/llm_model.py b/src/penai/llm/llm_model.py index 5f8b5b4..5a0b5ab 100644 --- a/src/penai/llm/llm_model.py +++ b/src/penai/llm/llm_model.py @@ -33,7 +33,7 @@ class RegisteredLLM(Enum): GEMINI_PRO = "gemini-pro" """Exists but I'm not sure which model it refers to. Gives different (better?) results than GEMINI_1_5_PRO.""" CLAUDE_3_OPUS = "claude-3-opus-20240229" - CLAUDE_3_5_SONNET = "claude-3-5-sonnet-20240620" + CLAUDE_3_5_SONNET = "claude-3-5-sonnet-20241022" def create_model( self, diff --git a/src/penai/llm/conversation.py b/src/penai/llm/prompting.py similarity index 90% rename from src/penai/llm/conversation.py rename to src/penai/llm/prompting.py index 1f5e231..30926aa 100644 --- a/src/penai/llm/conversation.py +++ b/src/penai/llm/prompting.py @@ -13,8 +13,10 @@ from langchain.globals import set_llm_cache from langchain.memory import ConversationBufferMemory from langchain_community.cache import SQLiteCache +from langchain_core.language_models import BaseLanguageModel from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from PIL.Image import Image +from pydantic import BaseModel from penai.config import get_config, pull_from_remote from penai.llm.llm_model import RegisteredLLM, RegisteredLLMParams @@ -250,3 +252,24 @@ def with_conditional_text(self, condition: bool, text: str) -> Self: def build(self) -> str: return self._content + + +class LLMBaseModel(BaseModel): + @classmethod + def from_llm(cls, model: BaseLanguageModel, messages: list[BaseMessage]) -> Self: + """Try to invoke the model with structured output and fall back to non-structured output if it is not available.""" + try: + model = model.with_structured_output(cls, method="json_mode") # type: ignore + response_dict = model.invoke(messages) + response = cls.model_validate(response_dict) + except ValueError: + conversation_response = Response(model.invoke(messages).content) + + try: + response_json = conversation_response.get_code_snippets()[0].code + except IndexError: + response_json = conversation_response.text + + response = cls.model_validate_json(response_json) + + return response diff --git a/src/penai/shape_name_generation/inference.py b/src/penai/shape_name_generation/inference.py index ee4bfa8..d3f4b5f 100644 --- a/src/penai/shape_name_generation/inference.py +++ b/src/penai/shape_name_generation/inference.py @@ -4,10 +4,10 @@ from langchain_core.messages import BaseMessage from PIL.Image import Image -from pydantic import BaseModel +from pydantic import ConfigDict -from penai.llm.conversation import MessageBuilder, Response from penai.llm.llm_model import RegisteredLLM, RegisteredLLMParams +from penai.llm.prompting import LLMBaseModel, MessageBuilder, Response from penai.render import BaseSVGRenderer from penai.svg import PenpotShapeElement from penai.utils.vis import DesignElementVisualizer, ShapeVisualization @@ -132,7 +132,10 @@ class SimplifiedShapeNameGeneratorOutput(NamedTuple): messages: list[BaseMessage] -class SimplifiedShapeNameGeneratorResponseSchema(BaseModel): +class SimplifiedShapeNameGeneratorResponseSchema(LLMBaseModel): + # Fix for https://github.com/pydantic/pydantic/discussions/7763 + model_config = ConfigDict(protected_namespaces=()) + name: str @@ -207,24 +210,8 @@ def generate_name_impl(self, shape: PenpotShapeElement) -> SimplifiedShapeNameGe ) messages = [message_builder.build_human_message()] - model = self.model.create_model(**self.model_options) - - if self.use_json_mode: - model = model.with_structured_output( - SimplifiedShapeNameGeneratorResponseSchema, method="json_mode" - ) - response_dict = model.invoke(messages) - response = SimplifiedShapeNameGeneratorResponseSchema.model_validate(response_dict) - else: - conversation_response = Response(model.invoke(messages).content) - - try: - response_json = conversation_response.get_code_snippets()[0].code - except IndexError: - response_json = conversation_response.text - - response = SimplifiedShapeNameGeneratorResponseSchema.model_validate_json(response_json) + response = SimplifiedShapeNameGeneratorResponseSchema.from_llm(model, messages) return SimplifiedShapeNameGeneratorOutput( name=response.name, diff --git a/src/penai/variations/svg_variations.py b/src/penai/variations/svg_variations.py index 8892679..84babe2 100644 --- a/src/penai/variations/svg_variations.py +++ b/src/penai/variations/svg_variations.py @@ -8,8 +8,8 @@ from sensai.util.logging import datetime_tag from penai.config import get_config -from penai.llm.conversation import CodeSnippet, Conversation, PromptBuilder, Response from penai.llm.llm_model import RegisteredLLM +from penai.llm.prompting import CodeSnippet, Conversation, PromptBuilder, Response from penai.models import PenpotColors from penai.svg import SVG, PenpotShapeElement from penai.types import PathLike diff --git a/src/penai/variations/xml_variations.py b/src/penai/variations/xml_variations.py index c040363..90e32e1 100644 --- a/src/penai/variations/xml_variations.py +++ b/src/penai/variations/xml_variations.py @@ -3,8 +3,8 @@ from sensai.util.logging import datetime_tag -from penai.llm.conversation import Conversation, Response from penai.llm.llm_model import RegisteredLLM +from penai.llm.prompting import Conversation, Response from penai.models import PenpotMinimalShapeXML from penai.svg import PenpotShapeElement from penai.types import PathLike