Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into refactor/make-output-…
Browse files Browse the repository at this point in the history
…schema-mandatory
  • Loading branch information
srijanpatel committed Jan 31, 2025
2 parents 5e40e5d + 1eda598 commit ca7e0a6
Show file tree
Hide file tree
Showing 19 changed files with 241 additions and 157 deletions.
3 changes: 1 addition & 2 deletions backend/app/api/evals_management.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from fastapi import APIRouter, HTTPException, BackgroundTasks, Depends
from sqlalchemy.orm import Session
from pathlib import Path
import yaml
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any
from datetime import datetime, timezone

from ..database import get_db
Expand Down
3 changes: 0 additions & 3 deletions backend/app/api/openai_compatible_api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import json
from datetime import datetime, timezone
from typing import Dict, Any, List, Optional, Union
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from pydantic import BaseModel
from sqlalchemy.orm import Session

from ..schemas.workflow_schemas import WorkflowDefinitionSchema
from ..execution.workflow_executor import WorkflowExecutor
from ..models.workflow_model import WorkflowModel
from ..database import get_db
from .workflow_run import run_workflow_blocking
Expand Down
28 changes: 18 additions & 10 deletions backend/app/execution/workflow_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pydantic import ValidationError

from ..nodes.base import BaseNodeOutput
from ..nodes.base import BaseNode, BaseNodeOutput
from ..nodes.factory import NodeFactory

from ..schemas.workflow_schemas import (
Expand Down Expand Up @@ -46,6 +46,7 @@ def __init__(
self.task_recorder = None
self.context = context
self._node_dict: Dict[str, WorkflowNodeSchema] = {}
self.node_instances: Dict[str, BaseNode] = {}
self._dependencies: Dict[str, Set[str]] = {}
self._node_tasks: Dict[str, asyncio.Task[Optional[BaseNodeOutput]]] = {}
self._initial_inputs: Dict[str, Dict[str, Any]] = {}
Expand Down Expand Up @@ -172,7 +173,7 @@ async def _execute_node(self, node_id: str) -> Optional[BaseNodeOutput]:
for dep_id in dependency_ids
),
)
except Exception as e:
except Exception:
raise UpstreamFailure(
f"Node {node_id} skipped due to upstream failure"
)
Expand Down Expand Up @@ -257,6 +258,7 @@ async def _execute_node(self, node_id: str) -> Optional[BaseNodeOutput]:
node_instance = NodeFactory.create_node(
node_name=node.title, node_type_name=node.node_type, config=node.config
)
self.node_instances[node_id] = node_instance
# Update task recorder
if self.task_recorder:
self.task_recorder.update_task(
Expand Down Expand Up @@ -300,7 +302,7 @@ async def _execute_node(self, node_id: str) -> Optional[BaseNodeOutput]:
f"Node Type: {node.node_type}\n"
f"Node Title: {node.title}\n"
f"Inputs: {node_input}\n"
f"Error: {str(e)}"
f"Error: {traceback.format_exc()}"
)
print(error_msg)
self._failed_nodes.add(node_id)
Expand All @@ -317,17 +319,23 @@ async def run(
self,
input: Dict[str, Any] = {},
node_ids: List[str] = [],
precomputed_outputs: Dict[str, Dict[str, Any]] = {},
precomputed_outputs: Dict[str, Dict[str, Any] | List[Dict[str, Any]]] = {},
) -> Dict[str, BaseNodeOutput]:
# Handle precomputed outputs first
if precomputed_outputs:
for node_id, output in precomputed_outputs.items():
try:
self._outputs[node_id] = NodeFactory.create_node(
node_name=self._node_dict[node_id].title,
node_type_name=self._node_dict[node_id].node_type,
config=self._node_dict[node_id].config,
).output_model.model_validate(output)
if isinstance(output, dict):
self._outputs[node_id] = NodeFactory.create_node(
node_name=self._node_dict[node_id].title,
node_type_name=self._node_dict[node_id].node_type,
config=self._node_dict[node_id].config,
).output_model.model_validate(output)
else:
# If output is a list of dicts, do not validate the output
# these are outputs of loop nodes, their precomputed outputs are not supported yet
continue

except ValidationError as e:
print(
f"[WARNING]: Precomputed output validation failed for node {node_id}: {e}\n skipping precomputed output"
Expand Down Expand Up @@ -387,7 +395,7 @@ async def __call__(
self,
input: Dict[str, Any] = {},
node_ids: List[str] = [],
precomputed_outputs: Dict[str, Dict[str, Any]] = {},
precomputed_outputs: Dict[str, Dict[str, Any] | List[Dict[str, Any]]] = {},
) -> Dict[str, BaseNodeOutput]:
"""
Execute the workflow with the given input data.
Expand Down
14 changes: 12 additions & 2 deletions backend/app/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,13 @@ def create_output_model_class(
else (field_type, ...) # try as is
)
for field_name, field_type in output_schema.items()
}, # type: ignore
},
__base__=BaseNodeOutput,
__config__=None,
__doc__=f"Output model for {self.name} node",
__module__=self.__module__,
__validators__=None,
__cls_kwargs__=None,
)

def create_composite_model_instance(
Expand All @@ -142,10 +147,15 @@ def create_composite_model_instance(
return create_model(
model_name,
**{
instance.__class__.__name__: (instance.__class__, ...) # type: ignore
instance.__class__.__name__: (instance.__class__, ...)
for instance in instances
},
__base__=BaseNodeInput,
__config__=None,
__doc__=f"Input model for {self.name} node",
__module__=self.__module__,
__validators__=None,
__cls_kwargs__=None,
)

async def __call__(
Expand Down
9 changes: 5 additions & 4 deletions backend/app/nodes/llm/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import os
import re
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, cast
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from docx2python import docx2python

import litellm
Expand Down Expand Up @@ -494,7 +493,7 @@ async def completion_with_backoff(**kwargs) -> str:
return response.choices[0].message.content

except Exception as e:
logging.error(f"=== LLM Request Error ===")
logging.error("=== LLM Request Error ===")
# Create a save copy of kwargs without sensitive information
save_config = kwargs.copy()
save_config["api_key"] = "********" if "api_key" in save_config else None
Expand Down Expand Up @@ -548,6 +547,8 @@ async def generate_text(
output_json_schema = convert_output_schema_to_json_schema(output_schema)
elif output_json_schema is not None and output_json_schema.strip() != "":
output_json_schema = json.loads(output_json_schema)
else:
raise ValueError("Invalid output schema", output_schema, output_json_schema)
output_json_schema["additionalProperties"] = False

# check if the model supports response format
Expand Down Expand Up @@ -783,7 +784,7 @@ def convert_docx_to_xml(file_path: str) -> str:
try:
with docx2python(file_path) as docx_content:
# Convert the document content to XML format
xml_content = f"<?xml version='1.0' encoding='UTF-8'?>\n<document>\n"
xml_content = "<?xml version='1.0' encoding='UTF-8'?>\n<document>\n"

# Add metadata
xml_content += "<metadata>\n"
Expand Down
5 changes: 2 additions & 3 deletions backend/app/nodes/llm/generative/best_of_n.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class BestOfNNodeConfig(SingleLLMCallNodeConfig, BaseSubworkflowNodeConfig):
description="System message for the generation LLM",
)
user_message: str = Field(default="", description="User message template")
output_schema: Dict[str, str] = Field(default={"response": "str"})
output_schema: Dict[str, str] = Field(default={"response": "string"})


class BestOfNNodeInput(BaseNodeInput):
Expand Down Expand Up @@ -107,7 +107,7 @@ def setup_subworkflow(self) -> None:
"llm_info": self.config.llm_info.model_dump(),
"system_message": self.config.rating_prompt,
"user_message": "",
"output_schema": {"rating": "float"},
"output_schema": {"rating": "number"},
},
)
nodes.append(rate_node)
Expand Down Expand Up @@ -171,7 +171,6 @@ def setup_subworkflow(self) -> None:
id=output_node_id,
node_type="OutputNode",
config={
"output_schema": output_schema,
"output_map": {
f"{k}": f"pick_one_node.{k}" for k in output_schema.keys()
},
Expand Down
24 changes: 16 additions & 8 deletions backend/app/nodes/logic/coalesce.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,31 +50,39 @@ async def run(self, input: BaseModel) -> BaseModel:
for key in self.config.preferences: # {{ edit_1 }}
if key in data and data[key] is not None:
# Return the first non-None value according to preferences
output_model = create_model( # type: ignore
output_model = create_model(
f"{self.name}",
**{
k: (type(v), ...) for k, v in data[key].items()
}, # Only include the first non-null key # type: ignore
}, # Only include the first non-null key
__base__=CoalesceNodeOutput,
__config__=None,
__module__=self.__module__,
__doc__=f"Output model for {self.name} node",
__validators__=None,
__cls_kwargs__=None,
)
self.output_model = output_model
first_non_null_output = data[key]
return self.output_model(**first_non_null_output) # type: ignore
return self.output_model(**first_non_null_output)

# If all preferred values are None, check the rest of the data
for key, value in data.items():
if value is not None:
# Return the first non-None value immediately
output_model = create_model( # type: ignore
output_model = create_model(
f"{self.name}",
**{
key: (type(value), ...)
}, # Only include the first non-null key # type: ignore
**{key: (type(value), ...)}, # Only include the first non-null key
__base__=CoalesceNodeOutput,
__config__=None,
__module__=self.__module__,
__doc__=f"Output model for {self.name} node",
__validators__=None,
__cls_kwargs__=None,
)
self.output_model = output_model
first_non_null_output[key] = value
return self.output_model(**first_non_null_output) # type: ignore
return self.output_model(**first_non_null_output)

# If all values are None, return an empty output
return None # type: ignore
29 changes: 21 additions & 8 deletions backend/app/nodes/logic/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,30 +130,43 @@ async def run(self, input: BaseModel) -> BaseModel:
Evaluates conditions for each route in order. The first route that matches
gets the input data. If no routes match, the first route acts as a default.
"""
output_model = create_model( # type: ignore
output_model = create_model(
f"{self.name}",
**{field_name: (field_type, ...) for field_name, field_type in input.model_fields.items()}, # type: ignore
__config__=None,
__base__=RouterNodeOutput,
__doc__=f"Output model for {self.name} node",
__module__=self.__module__,
__validators__=None,
__cls_kwargs__=None,
**{
field_name: (field_type, None)
for field_name, field_type in input.model_fields.items()
},
)
# Create fields for each route with Optional[input type]
route_fields = { # type: ignore
route_name: (Optional[output_model], None) # type: ignore
route_fields = {
route_name: (Optional[output_model], None)
for route_name in self.config.route_map.keys()
}
new_output_model = create_model( # type: ignore
new_output_model = create_model(
f"{self.name}CompositeOutput",
__base__=RouterNodeOutput,
**route_fields, # type: ignore
__config__=None,
__doc__=f"Composite output model for {self.name} node",
__module__=self.__module__,
__validators__=None,
__cls_kwargs__=None,
**route_fields,
)
self.output_model = new_output_model

output: Dict[str, Optional[BaseModel]] = {}

for route_name, route in self.config.route_map.items():
if self._evaluate_route_conditions(input, route):
output[route_name] = output_model(**input.model_dump()) # type: ignore
output[route_name] = output_model(**input.model_dump())

return self.output_model(**output) # type: ignore
return self.output_model(**output)


if __name__ == "__main__":
Expand Down
39 changes: 26 additions & 13 deletions backend/app/nodes/loops/base_loop_subworkflow_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import abstractmethod
from typing import Any, Dict, List
from pydantic import BaseModel
from pydantic import BaseModel, create_model

from ..primitives.output import OutputNode

from ..base import BaseNodeInput, BaseNodeOutput
from ...execution.workflow_executor import WorkflowExecutor
Expand Down Expand Up @@ -62,9 +64,10 @@ async def run_iteration(self, input: Dict[str, Any]) -> Dict[str, Any]:
iteration_input = {**input, "loop_history": self.loop_outputs}

# Execute the subworkflow
workflow_executor = WorkflowExecutor(
workflow=self.subworkflow, context=self.context
self._executor = WorkflowExecutor(
workflow=self.config.subworkflow, context=self.context
)
workflow_executor = self._executor
outputs = await workflow_executor.run(iteration_input)

# Convert outputs to dict format
Expand All @@ -83,16 +86,6 @@ async def run_iteration(self, input: Dict[str, Any]) -> Dict[str, Any]:

async def run(self, input: BaseModel) -> BaseModel:
"""Execute the loop subworkflow until stopping condition is met"""
# Create output model dynamically based on the schema of the output node
output_node = next(
node
for node in self.config.subworkflow.nodes
if node.node_type == "OutputNode"
)
self.output_model = self.create_output_model_class(
output_node.config.get("output_schema", {})
)

current_input = self._map_input(input)

# Run iterations until stopping condition is met
Expand All @@ -103,5 +96,25 @@ async def run(self, input: BaseModel) -> BaseModel:

self.subworkflow_output = self.loop_outputs

# create output model for the loop from the subworkflow output node's output_model
output_node = next(
node
for _id, node in self._executor.node_instances.items()
if issubclass(node.__class__, OutputNode)
)
self.output_model = create_model(
f"{self.name}",
**{
name: (field, ...)
for name, field in output_node.output_model.model_fields.items()
},
__base__=BaseLoopSubworkflowNodeOutput,
__config__=None,
__module__=self.__module__,
__cls_kwargs__={"arbitrary_types_allowed": True},
__doc__=None,
__validators__=None,
)

# Return final state as BaseModel
return self.output_model.model_validate(current_input) # type: ignore
Loading

0 comments on commit ca7e0a6

Please sign in to comment.