88from patchwork .step import Step , StepStatus
99
1010
11- def save_file_contents (file_path , content ):
12- """Utility function to save content to a file."""
13- with open (file_path , "w" ) as file :
11+ def save_file_contents (file_path : str | Path , content : str ) -> None :
12+ """Utility function to save content to a file.
13+
14+ Args:
15+ file_path: Path to the file to save content to (str or Path)
16+ content: Content to write to the file
17+ """
18+ path = Path (file_path )
19+ with path .open ("w" ) as file :
1420 file .write (content )
1521
1622
@@ -36,20 +42,26 @@ def handle_indent(src: list[str], target: list[str], start: int, end: int) -> li
3642
3743
3844def replace_code_in_file (
39- file_path : str ,
45+ file_path : str | Path ,
4046 start_line : int | None ,
4147 end_line : int | None ,
4248 new_code : str ,
4349) -> None :
50+ """Replace code in a file at the specified line range.
51+
52+ Args:
53+ file_path: Path to the file to modify (str or Path)
54+ start_line: Starting line number (1-based)
55+ end_line: Ending line number (1-based)
56+ new_code: New code to insert
57+ """
4458 path = Path (file_path )
4559 new_code_lines = new_code .splitlines (keepends = True )
4660 if len (new_code_lines ) > 0 and not new_code_lines [- 1 ].endswith ("\n " ):
4761 new_code_lines [- 1 ] += "\n "
4862
4963 if path .exists () and start_line is not None and end_line is not None :
50- """Replaces specified lines in a file with new code."""
5164 text = path .read_text ()
52-
5365 lines = text .splitlines (keepends = True )
5466
5567 # Insert the new code at the start line after converting it into a list of lines
@@ -58,7 +70,7 @@ def replace_code_in_file(
5870 lines = new_code_lines
5971
6072 # Save the modified contents back to the file
61- save_file_contents (file_path , "" .join (lines ))
73+ save_file_contents (path , "" .join (lines ))
6274
6375
6476class ModifyCode (Step ):
@@ -84,7 +96,8 @@ def run(self) -> dict:
8496 return dict (modified_code_files = [])
8597
8698 for code_snippet , extracted_response in sorted_list :
87- uri = code_snippet .get ("uri" )
99+ # Use Path for consistent path handling
100+ file_path = Path (code_snippet .get ("uri" , "" ))
88101 start_line = code_snippet .get ("startLine" )
89102 end_line = code_snippet .get ("endLine" )
90103 new_code = extracted_response .get ("patch" )
@@ -93,41 +106,68 @@ def run(self) -> dict:
93106 continue
94107
95108 # Get the original content for diffing
96- file_path = Path (uri )
97- original_path = None
98109 diff = ""
99110
100111 if file_path .exists ():
101- # Create a temporary copy of the original file
102- with tempfile .NamedTemporaryFile (mode = 'w' , delete = False ) as tmp_file :
103- shutil .copy2 (uri , tmp_file .name )
104- original_path = tmp_file .name
105-
106- # Apply the changes
107- replace_code_in_file (uri , start_line , end_line , new_code )
108-
109- # Generate a proper unified diff
110- with open (original_path , 'r' ) as f1 , open (uri , 'r' ) as f2 :
111- diff = '' .join (difflib .unified_diff (
112- f1 .readlines (),
113- f2 .readlines (),
114- fromfile = 'a/' + str (file_path ),
115- tofile = 'b/' + str (file_path )
116- ))
117-
118- # Clean up temporary file
119- Path (original_path ).unlink ()
112+ try :
113+ # Create a temporary directory with restricted permissions
114+ with tempfile .TemporaryDirectory (prefix = 'modifycode_' ) as temp_dir :
115+ # Create temporary file path within the secure directory
116+ temp_path = Path (temp_dir ) / 'original_file'
117+
118+ # Copy original file with same permissions
119+ shutil .copy2 (file_path , temp_path )
120+
121+ # Store original content
122+ with temp_path .open ('r' ) as f1 :
123+ original_lines = f1 .readlines ()
124+
125+ # Apply the changes
126+ replace_code_in_file (file_path , start_line , end_line , new_code )
127+
128+ # Read modified content
129+ with file_path .open ('r' ) as f2 :
130+ modified_lines = f2 .readlines ()
131+
132+ # Generate a proper unified diff
133+ # Use Path for consistent path handling
134+ relative_path = str (file_path )
135+ diff = '' .join (difflib .unified_diff (
136+ original_lines ,
137+ modified_lines ,
138+ fromfile = str (Path ('a' ) / relative_path ),
139+ tofile = str (Path ('b' ) / relative_path )
140+ ))
141+
142+ # temp_dir and its contents are automatically cleaned up
143+ except (OSError , IOError ) as e :
144+ print (f"Warning: Failed to generate diff for { file_path } : { str (e )} " )
145+ # Still proceed with the modification even if diff generation fails
146+ replace_code_in_file (file_path , start_line , end_line , new_code )
120147 else :
121148 # If file doesn't exist, just store the new code as the diff
122- diff = f"+++ { file_path } \n { new_code } "
149+ # Use Path for consistent path handling
150+ relative_path = str (file_path )
151+ diff = f"+++ { Path (relative_path )} \n { new_code } "
123152
153+ # Create and validate the modified code file dictionary
124154 modified_code_file = dict (
125- path = uri ,
155+ path = str ( file_path ) ,
126156 start_line = start_line ,
127157 end_line = end_line ,
128158 diff = diff ,
129159 ** extracted_response
130160 )
161+
162+ # Ensure all required fields are present with correct types
163+ if not isinstance (modified_code_file ["path" ], str ):
164+ raise TypeError (f"path must be str, got { type (modified_code_file ['path' ])} " )
165+ if not isinstance (modified_code_file ["start_line" ], (int , type (None ))):
166+ raise TypeError (f"start_line must be int or None, got { type (modified_code_file ['start_line' ])} " )
167+ if not isinstance (modified_code_file ["end_line" ], (int , type (None ))):
168+ raise TypeError (f"end_line must be int or None, got { type (modified_code_file ['end_line' ])} " )
169+ if not isinstance (modified_code_file ["diff" ], str ):
170+ raise TypeError (f"diff must be str, got { type (modified_code_file ['diff' ])} " )
131171 modified_code_files .append (modified_code_file )
132172
133173 return dict (modified_code_files = modified_code_files )
0 commit comments