|  | 
|  | 1 | +from pathlib import Path | 
|  | 2 | +from typing import List, Union, Optional, Literal, Any, Annotated, Type, TypeVar | 
|  | 3 | +from pydantic import BaseModel, Field, field_validator, field_serializer | 
|  | 4 | +from pydantic import ValidationError | 
|  | 5 | +import json | 
|  | 6 | +import logging | 
|  | 7 | + | 
|  | 8 | +logger = logging.getLogger(__name__) | 
|  | 9 | + | 
# Reserved port name used internally to represent an edge's default (unnamed)
# source output; mapped to/from None at the JSON boundary by the Edge model.
INTERNAL_DEFAULT_HANDLE = "__result__"
# Type variable bound to the workflow model, used in classmethod signatures.
T = TypeVar("T", bound="PythonWorkflowDefinitionWorkflow")

# Explicit public API of this module.
__all__ = (
    "PythonWorkflowDefinitionInputNode",
    "PythonWorkflowDefinitionOutputNode",
    "PythonWorkflowDefinitionFunctionNode",
    "PythonWorkflowDefinitionEdge",
    "PythonWorkflowDefinitionWorkflow",
)
|  | 20 | + | 
|  | 21 | + | 
class PythonWorkflowDefinitionBaseNode(BaseModel):
    """Base model for all node types, containing common fields.

    Not used directly in the node union; subclasses narrow ``type`` to a
    ``Literal`` so pydantic can dispatch on it as a discriminator.
    """

    # Unique numeric identifier of the node; edges reference nodes by this id.
    id: int
    # The 'type' field will be overridden in subclasses with Literal types
    # to enable discriminated unions.
    type: str
|  | 29 | + | 
|  | 30 | + | 
class PythonWorkflowDefinitionInputNode(PythonWorkflowDefinitionBaseNode):
    """Model for input nodes.

    An input node names a workflow-level input and may carry a default
    value of any JSON-compatible type.
    """

    # Discriminator value for this node kind.
    type: Literal["input"]
    # Name of the workflow input this node provides.
    name: str
    # Optional preset value for the input; None when the value is supplied at runtime.
    value: Optional[Any] = None
|  | 37 | + | 
|  | 38 | + | 
class PythonWorkflowDefinitionOutputNode(PythonWorkflowDefinitionBaseNode):
    """Model for output nodes.

    An output node names a workflow-level result; the producing node is
    determined by the edges, not by this model.
    """

    # Discriminator value for this node kind.
    type: Literal["output"]
    # Name under which the workflow exposes this output.
    name: str
|  | 44 | + | 
|  | 45 | + | 
class PythonWorkflowDefinitionFunctionNode(PythonWorkflowDefinitionBaseNode):
    """
    Model for function execution nodes.

    Attributes:
        type: Discriminator literal, always ``"function"``.
        value: Dotted import path of the callable to execute, in
            'module.function' format (validated below).
    """

    type: Literal["function"]
    value: str  # Expected format: 'module.function'

    @field_validator("value")
    @classmethod
    def check_value_format(cls, v: str) -> str:
        """Validate that 'value' is a dotted 'module.function' path.

        Raises:
            ValueError: If the value is empty, contains no period, or has a
                leading or trailing period.
        """
        if not v or "." not in v or v.startswith(".") or v.endswith("."):
            # BUG FIX: 'msg' was previously built with commas inside the
            # parentheses, which made it a *tuple* of two strings rather than
            # one string, so ValueError rendered an unreadable tuple repr.
            # Implicit string concatenation (no commas) yields one message.
            msg = (
                "FunctionNode 'value' must be a non-empty string "
                "in 'module.function' format with at least one period."
            )
            raise ValueError(msg)
        return v
|  | 65 | + | 
|  | 66 | + | 
|  | 67 | +# Discriminated Union for Nodes | 
# Discriminated Union for Nodes: pydantic selects the concrete node model by
# inspecting the 'type' field ("input" / "output" / "function") instead of
# trying each member in order, giving precise validation errors.
PythonWorkflowDefinitionNode = Annotated[
    Union[
        PythonWorkflowDefinitionInputNode,
        PythonWorkflowDefinitionOutputNode,
        PythonWorkflowDefinitionFunctionNode,
    ],
    Field(discriminator="type"),
]
|  | 76 | + | 
|  | 77 | + | 
class PythonWorkflowDefinitionEdge(BaseModel):
    """Model for edges connecting nodes.

    An edge wires the ``sourcePort`` output of node ``source`` to the
    ``targetPort`` input of node ``target``. A ``sourcePort`` of None in the
    JSON denotes the source node's default output and is represented
    internally by ``INTERNAL_DEFAULT_HANDLE``.
    """

    target: int
    targetPort: Optional[str] = None
    source: int
    sourcePort: Optional[str] = None

    @field_validator("sourcePort", mode="before")
    @classmethod
    def handle_default_source(cls, value: Any) -> Optional[str]:
        """Resolve an omitted sourcePort to the internal default handle.

        Runs before standard validation: None is translated to
        ``INTERNAL_DEFAULT_HANDLE``, while explicitly supplying the reserved
        handle name is rejected so it stays internal-only.
        """
        # Allow not specifying the sourcePort -> null gets resolved to __result__
        if value is None:
            return INTERNAL_DEFAULT_HANDLE
        if value == INTERNAL_DEFAULT_HANDLE:
            # Disallow explicit use of the internal reserved handle name
            raise ValueError(
                f"Explicit use of reserved sourcePort '{INTERNAL_DEFAULT_HANDLE}' "
                f"is not allowed. Use null/None for default output."
            )
        return value

    @field_serializer("sourcePort")
    def serialize_source_handle(self, value: Optional[str]) -> Optional[str]:
        """
        SERIALIZATION (Output): Converts internal INTERNAL_DEFAULT_HANDLE
        ("__result__") back to None; other handle names pass through unchanged.
        """
        return None if value == INTERNAL_DEFAULT_HANDLE else value
|  | 114 | + | 
|  | 115 | + | 
class PythonWorkflowDefinitionWorkflow(BaseModel):
    """The main workflow model: a graph of typed nodes connected by edges."""

    # Nodes are validated through the discriminated union on their 'type' field.
    nodes: List[PythonWorkflowDefinitionNode]
    edges: List[PythonWorkflowDefinitionEdge]

    def dump_json(
        self,
        *,
        indent: Optional[int] = 2,
        **kwargs,
    ) -> str:
        """
        Dumps the workflow model to a JSON string.

        Args:
            indent: JSON indentation level (None produces compact output).
            **kwargs: Additional keyword arguments passed to Pydantic's model_dump
                (e.g. custom 'exclude' rules).

        Returns:
            JSON string representation of the workflow.

        Raises:
            TypeError: If the dumped dictionary contains values json.dumps
                cannot serialize (logged, then re-raised).
        """

        # Dump the model to a dictionary first, using mode='json' for compatible types
        # Pass any extra kwargs (like custom 'exclude' rules for other fields)
        workflow_dict = self.model_dump(mode="json", **kwargs)

        # Dump the dictionary to a JSON string
        try:
            json_string = json.dumps(workflow_dict, indent=indent)
            logger.info("Successfully dumped workflow model to JSON string.")
            return json_string
        except TypeError as e:
            logger.error(
                f"Error serializing workflow dictionary to JSON: {e}", exc_info=True
            )
            raise  # Re-raise after logging

    def dump_json_file(
        self,
        file_name: Union[str, Path],
        *,
        indent: Optional[int] = 2,
        **kwargs,
    ) -> None:
        """
        Dumps the workflow model to a JSON file.

        Args:
            file_name: Path to the output JSON file.
            indent: JSON indentation level (None produces compact output).
            **kwargs: Additional keyword arguments passed to Pydantic's model_dump.

        Raises:
            IOError: If the file cannot be written (logged, then re-raised).
        """
        logger.info(f"Dumping workflow model to JSON file: {file_name}")
        # Pass kwargs to dump_json, which passes them to model_dump
        json_string = self.dump_json(
            indent=indent,
            **kwargs,
        )
        try:
            with open(file_name, "w", encoding="utf-8") as f:
                f.write(json_string)
            logger.info(f"Successfully wrote workflow model to {file_name}.")
        except IOError as e:
            logger.error(
                f"Error writing workflow model to file {file_name}: {e}", exc_info=True
            )
            raise

    @classmethod
    def load_json_str(cls: Type[T], json_data: Union[str, bytes]) -> dict:
        """
        Loads and validates workflow data from a JSON string or bytes.

        Args:
            json_data: The JSON data as a string or bytes.

        Returns:
            A plain dict representation of the validated workflow
            (the validated model re-dumped via model_dump).

        Raises:
            pydantic.ValidationError: If validation fails.
            json.JSONDecodeError: If json_data is not valid JSON.
        """
        logger.info("Loading workflow model from JSON data...")
        try:
            # Pydantic v2 method handles bytes or str directly
            instance = cls.model_validate_json(json_data)
            # Pydantic v1 equivalent: instance = cls.parse_raw(json_data)
            logger.info(
                "Successfully loaded and validated workflow model from JSON data."
            )
            return instance.model_dump()
        except ValidationError:  # Catch validation errors specifically
            logger.error("Workflow model validation failed.", exc_info=True)
            raise
        except json.JSONDecodeError:  # Catch JSON parsing errors specifically
            # NOTE(review): pydantic v2's model_validate_json appears to report
            # malformed JSON as a ValidationError, so this branch may be
            # unreachable — confirm against the pinned pydantic version.
            logger.error("Invalid JSON format encountered.", exc_info=True)
            raise
        except Exception as e:  # Catch any other unexpected errors
            logger.error(
                f"An unexpected error occurred during JSON loading: {e}", exc_info=True
            )
            raise

    @classmethod
    def load_json_file(cls: Type[T], file_name: Union[str, Path]) -> dict:
        """
        Loads and validates workflow data from a JSON file.

        Args:
            file_name: The path to the JSON file.

        Returns:
            A plain dict representation of the validated workflow
            (see load_json_str).

        Raises:
            FileNotFoundError: If the file is not found.
            pydantic.ValidationError: If validation fails.
            json.JSONDecodeError: If the file is not valid JSON.
            IOError: If there are other file reading issues.
        """
        logger.info(f"Loading workflow model from JSON file: {file_name}")
        try:
            file_content = Path(file_name).read_text(encoding="utf-8")
            # Delegate validation to the string loading method
            return cls.load_json_str(file_content)
        except FileNotFoundError:
            logger.error(f"JSON file not found: {file_name}", exc_info=True)
            raise
        except IOError as e:
            logger.error(f"Error reading JSON file {file_name}: {e}", exc_info=True)
            raise
0 commit comments