Skip to content

Commit

Permalink
workflow deep copy
Browse files Browse the repository at this point in the history
  • Loading branch information
manthanguptaa committed Nov 29, 2024
1 parent c787e86 commit 6fe82f6
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 4 deletions.
9 changes: 5 additions & 4 deletions phi/playground/workflow_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,12 @@ async def run_workflow(workflow_id: str, body: WorkflowRunRequest):
workflow = get_workflow_by_id(workflows, workflow_id)
if workflow is None:
raise HTTPException(status_code=404, detail="Workflow not found")
workflow.user_id = body.user_id
if workflow._run_return_type == "RunResponse":
return workflow.run(**body.input)
workflow_copy = workflow.deep_copy(update={"workflow_id": workflow_id})
workflow_copy.user_id = body.user_id
if workflow_copy._run_return_type == "RunResponse":
return workflow_copy.run(**body.input)
return StreamingResponse(
(r.model_dump_json() for r in workflow.run(**body.input)), media_type="text/event-stream"
(r.model_dump_json() for r in workflow_copy.run(**body.input)), media_type="text/event-stream"
)

@workflow_router.post("/workflow/{workflow_id}/session/all")
Expand Down
65 changes: 65 additions & 0 deletions phi/workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from types import GeneratorType
from typing import Any, Optional, Callable, Dict

from phi.agent.agent import Agent
from pydantic import BaseModel, Field, ConfigDict, field_validator, PrivateAttr

from phi.run.response import RunResponse, RunEvent # noqa: F401
Expand Down Expand Up @@ -366,3 +367,67 @@ def delete_session(self, session_id: str):
if self.storage is None:
raise ValueError("Storage is not set")
self.storage.delete_session(session_id)

def deep_copy(self, *, update: Optional[Dict[str, Any]] = None) -> "Workflow":
"""Create and return a deep copy of this Workflow, optionally updating fields.
Args:
update (Optional[Dict[str, Any]]): Optional dictionary of fields for the new Workflow.
Returns:
Workflow: A new Workflow instance.
"""
# Extract the fields to set for the new Workflow
fields_for_new_workflow = {}

for field_name in self.model_fields_set:
field_value = getattr(self, field_name)
if field_value is not None:
if isinstance(field_value, Agent):
fields_for_new_workflow[field_name] = field_value.deep_copy()
else:
fields_for_new_workflow[field_name] = self._deep_copy_field(field_name, field_value)

# Update fields if provided
if update:
fields_for_new_workflow.update(update)

# Create a new Workflow
new_workflow = self.__class__(**fields_for_new_workflow)
logger.debug(f"Created new Workflow: workflow_id: {new_workflow.workflow_id} | session_id: {new_workflow.session_id}")
return new_workflow

def _deep_copy_field(self, field_name: str, field_value: Any) -> Any:
"""Helper method to deep copy a field based on its type."""
from copy import copy, deepcopy

# For memory, use its deep_copy method
if field_name == "memory":
return field_value.deep_copy()

# For compound types, attempt a deep copy
if isinstance(field_value, (list, dict, set, WorkflowStorage)):
try:
return deepcopy(field_value)
except Exception as e:
logger.warning(f"Failed to deepcopy field: {field_name} - {e}")
try:
return copy(field_value)
except Exception as e:
logger.warning(f"Failed to copy field: {field_name} - {e}")
return field_value

# For pydantic models, attempt a deep copy
if isinstance(field_value, BaseModel):
try:
return field_value.model_copy(deep=True)
except Exception as e:
logger.warning(f"Failed to deepcopy field: {field_name} - {e}")
try:
return field_value.model_copy(deep=False)
except Exception as e:
logger.warning(f"Failed to copy field: {field_name} - {e}")
return field_value

# For other types, return as is
return field_value

0 comments on commit 6fe82f6

Please sign in to comment.