Skip to content

Commit

Permalink
Retry diff logic
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOsika committed May 2, 2024
1 parent a0794bf commit 247dc4c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 23 deletions.
2 changes: 1 addition & 1 deletion gpt_engineer/core/default/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
MAX_EDIT_REFINEMENT_STEPS : int
The maximum number of refinement steps allowed when generating edit blocks.
"""
MAX_EDIT_REFINEMENT_STEPS = 0
MAX_EDIT_REFINEMENT_STEPS = 2
39 changes: 17 additions & 22 deletions gpt_engineer/core/default/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,34 +313,29 @@ def improve_fn(
def _improve_loop(
ai: AI, files_dict: FilesDict, memory: BaseMemory, messages: List
) -> FilesDict:
problems = []
messages = ai.next(messages, step_name=curr_fn())

# check edit correctness
edit_refinements = 0
while edit_refinements <= MAX_EDIT_REFINEMENT_STEPS:
messages = ai.next(messages, step_name=curr_fn())
files_dict = salvage_correct_hunks(messages, files_dict, problems, memory)

# if len(problems) > 0:
# messages.append(
# HumanMessage(
# content="Some previously produced diffs were not on the requested format, or the code part was not found in the code. Details: "
# + "\n".join(problems)
# + "\n Only rewrite the problematic diffs, making sure that the failing ones are now on the correct format and can be found in the code. Make sure to not repeat past mistakes. \n"
# )
# )
# messages = ai.next(messages, step_name=curr_fn())
# edit_refinements += 1
# files_dict = salvage_correct_hunks(messages, files_dict, problems)
for _ in range(MAX_EDIT_REFINEMENT_STEPS):
files_dict, errors = salvage_correct_hunks(messages, files_dict, memory)
if not errors:
break
messages.append(
HumanMessage(
content="Some previously produced diffs were not on the requested format, or the code part was not found in the code. Details:\n"
+ "\n".join(errors)
+ "\n Only rewrite the problematic diffs, making sure that the failing ones are now on the correct format and can be found in the code. Make sure to not repeat past mistakes. \n"
)
)
return files_dict


def salvage_correct_hunks(
messages: List,
files_dict: FilesDict,
error_message: List,
memory: BaseMemory,
) -> FilesDict:
) -> tuple[FilesDict, List[str]]:
error_messages = []
ai_response = messages[-1].content.strip()

diffs = parse_diffs(ai_response)
Expand All @@ -352,11 +347,11 @@ def salvage_correct_hunks(
problems = diff.validate_and_correct(
file_to_lines_dict(files_dict[diff.filename_pre])
)
error_message.extend(problems)
error_messages.extend(problems)
files_dict = apply_diffs(diffs, files_dict)
memory.log(IMPROVE_LOG_FILE, "\n\n".join(x.pretty_repr() for x in messages))
memory.log(DIFF_LOG_FILE, "\n\n".join(error_message))
return files_dict
memory.log(DIFF_LOG_FILE, "\n\n".join(error_messages))
return files_dict, error_messages


class Tee(object):
Expand Down

0 comments on commit 247dc4c

Please sign in to comment.