Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect and report infinite loops #1652

Closed
wants to merge 2 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
find loops and preserve extra_info in errors
  • Loading branch information
chrisgoringe committed Oct 2, 2023
commit 7499011289d21f4f452b9eb189e7a36da00dd04f
36 changes: 28 additions & 8 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):



def validate_inputs(prompt, item, validated):
def validate_inputs(prompt, item, validated, stack=[]):
unique_id = item
if unique_id in validated:
return validated[unique_id]
Expand All @@ -399,6 +399,20 @@ def validate_inputs(prompt, item, validated):
errors = []
valid = True

if unique_id in stack:
error = {
"type": "infinite_loop",
"message": "loop detected in workflow validation",
"details": f"detected at {unique_id}",
"extra_info": {"stack": f"{stack}"},
}
errors.append(error)
ret = (False, errors, unique_id)
validated[unique_id] = ret
# don't continue, because we're already here further up the stack
return ret
stack.append(unique_id)

for x in required_inputs:
if x not in inputs:
error = {
Expand Down Expand Up @@ -450,7 +464,7 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
try:
r = validate_inputs(prompt, o_id, validated)
r = validate_inputs(prompt, o_id, validated, stack)
if r[0] is False:
# `r` will be set in `validated[o_id]` already
valid = False
Expand Down Expand Up @@ -582,8 +596,10 @@ def validate_inputs(prompt, item, validated):
else:
ret = (True, [], unique_id)

validated[unique_id] = ret
return ret
# if we had a loop, unique_id will have been marked invalid further down the tree
if unique_id not in validated:
validated[unique_id] = ret
return validated[unique_id]

def full_type_name(klass):
module = klass.__module__
Expand Down Expand Up @@ -615,7 +631,7 @@ def validate_prompt(prompt):
valid = False
reasons = []
try:
m = validate_inputs(prompt, o, validated)
m = validate_inputs(prompt, o, validated, [])
valid = m[0]
reasons = m[1]
except Exception as ex:
Expand Down Expand Up @@ -664,16 +680,20 @@ def validate_prompt(prompt):

if len(good_outputs) == 0:
errors_list = []
extra_info = {}
for o, errors in errors:
for error in errors:
errors_list.append(f"{error['message']}: {error['details']}")
if errors:
extra_info[o] = []
for error in errors:
errors_list.append(f"{error['message']}: {error['details']}")
extra_info[o].append(error.get('extra_info',""))
errors_list = "\n".join(errors_list)

error = {
"type": "prompt_outputs_failed_validation",
"message": "Prompt outputs failed validation",
"details": errors_list,
"extra_info": {}
"extra_info": extra_info,
}

return (False, error, list(good_outputs), node_errors)
Expand Down