|
13 | 13 |
|
14 | 14 | import tempfile |
15 | 15 | import unittest |
16 | | -import os |
17 | 16 | import zipfile |
18 | 17 | import tarfile |
19 | 18 | from pathlib import Path |
@@ -71,181 +70,181 @@ def test_default(self, key, file_type): |
71 | 70 |
|
72 | 71 | class TestPathTraversalProtection(unittest.TestCase): |
73 | 72 | """Test cases for path traversal attack protection in extractall function.""" |
74 | | - |
| 73 | + |
def test_valid_zip_extraction(self):
    """Check that a well-formed zip archive is extracted intact.

    Builds a zip holding two top-level files and one nested file, extracts
    it with ``extractall`` and then verifies that every member exists and
    that the text content of ``normal_file.txt`` survives the round trip.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        zip_path = Path(tmp_dir) / "valid_test.zip"
        extract_dir = Path(tmp_dir) / "extract"
        extract_dir.mkdir()

        # Archive with an ordinary, traversal-free layout.
        with zipfile.ZipFile(zip_path, 'w') as zf:
            zf.writestr("normal_file.txt", "This is a normal file")
            zf.writestr("subdir/nested_file.txt", "This is a nested file")
            zf.writestr("another_file.json", '{"key": "value"}')

        # Guard only the extraction call. The verification asserts below
        # must stay outside the ``try``: AssertionError is an Exception, so
        # wrapping them would relabel a failed check as an extraction error.
        try:
            extractall(str(zip_path), str(extract_dir))
        except Exception as e:
            self.fail(f"Valid zip extraction should not raise exception: {e}")

        # Every member must be present at its expected relative location.
        self.assertTrue((extract_dir / "normal_file.txt").exists())
        self.assertTrue((extract_dir / "subdir" / "nested_file.txt").exists())
        self.assertTrue((extract_dir / "another_file.json").exists())

        # Explicit encoding keeps the comparison independent of the locale.
        extracted_text = (extract_dir / "normal_file.txt").read_text(encoding="utf-8")
        self.assertEqual(extracted_text, "This is a normal file")
104 | | - |
| 103 | + |
def test_malicious_zip_path_traversal(self):
    """Expect a zip member that climbs out of the target to be rejected.

    The archive mixes one ``../`` member with one ordinary member. The test
    requires ``extractall`` to raise ``ValueError`` with a message that
    contains "unsafe path" (compared case-insensitively).
    """
    with tempfile.TemporaryDirectory() as workspace:
        root = Path(workspace)
        archive = root / "malicious_test.zip"
        destination = root / "extract"
        destination.mkdir()

        # Insertion order matters: the traversal entry is written first.
        members = {
            "../../../etc/malicious.txt": "malicious content",
            "normal_file.txt": "normal content",
        }
        with zipfile.ZipFile(archive, 'w') as bundle:
            for member_name, payload in members.items():
                bundle.writestr(member_name, payload)

        with self.assertRaises(ValueError) as caught:
            extractall(str(archive), str(destination))

        message = str(caught.exception).lower()
        self.assertIn("unsafe path", message)
124 | | - |
| 123 | + |
def test_valid_tar_extraction(self):
    """Check that a well-formed gzip-compressed tar is extracted intact.

    Adds one top-level and one nested member to a ``.tar.gz`` archive,
    extracts it with ``extractall`` and then verifies that both members
    exist and that ``normal_file.txt`` has the expected text content.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        tar_path = Path(tmp_dir) / "valid_test.tar.gz"
        extract_dir = Path(tmp_dir) / "extract"
        extract_dir.mkdir()

        # Archive with an ordinary, traversal-free layout. The source files
        # live next to the archive and are renamed through ``arcname``.
        with tarfile.open(tar_path, 'w:gz') as tf:
            temp_file1 = Path(tmp_dir) / "temp1.txt"
            temp_file1.write_text("This is a normal file")
            tf.add(temp_file1, arcname="normal_file.txt")

            temp_file2 = Path(tmp_dir) / "temp2.txt"
            temp_file2.write_text("This is a nested file")
            tf.add(temp_file2, arcname="subdir/nested_file.txt")

        # Guard only the extraction call. The verification asserts below
        # must stay outside the ``try``: AssertionError is an Exception, so
        # wrapping them would relabel a failed check as an extraction error.
        try:
            extractall(str(tar_path), str(extract_dir))
        except Exception as e:
            self.fail(f"Valid tar extraction should not raise exception: {e}")

        # Every member must be present at its expected relative location.
        self.assertTrue((extract_dir / "normal_file.txt").exists())
        self.assertTrue((extract_dir / "subdir" / "nested_file.txt").exists())

        # The payload is plain ASCII, so decoding as UTF-8 is always valid
        # and keeps the comparison independent of the locale.
        extracted_text = (extract_dir / "normal_file.txt").read_text(encoding="utf-8")
        self.assertEqual(extracted_text, "This is a normal file")
158 | | - |
| 157 | + |
def test_malicious_tar_path_traversal(self):
    """Expect a tar member that climbs out of the target to be rejected.

    A single file is stored under a ``../`` archive name. The test requires
    ``extractall`` to raise ``ValueError`` with a message that contains
    "unsafe path" (compared case-insensitively).
    """
    with tempfile.TemporaryDirectory() as workspace:
        root = Path(workspace)
        archive = root / "malicious_test.tar.gz"
        destination = root / "extract"
        destination.mkdir()

        # Harmless on-disk file; only its name inside the archive is hostile.
        payload = root / "temp.txt"
        payload.write_text("malicious content")

        with tarfile.open(archive, 'w:gz') as bundle:
            bundle.add(payload, arcname="../../../etc/malicious.txt")

        with self.assertRaises(ValueError) as caught:
            extractall(str(archive), str(destination))

        message = str(caught.exception).lower()
        self.assertIn("unsafe path", message)
181 | | - |
| 180 | + |
def test_absolute_path_protection(self):
    """Expect a zip member stored under an absolute name to be rejected.

    The archive contains one entry named ``/etc/passwd``. The test requires
    ``extractall`` to raise ``ValueError`` with a message that contains
    "unsafe path" (compared case-insensitively).
    """
    with tempfile.TemporaryDirectory() as workspace:
        root = Path(workspace)
        archive = root / "absolute_path_test.zip"
        destination = root / "extract"
        destination.mkdir()

        # One entry whose stored name starts at the filesystem root.
        with zipfile.ZipFile(archive, 'w') as bundle:
            bundle.writestr("/etc/passwd", "malicious content")

        with self.assertRaises(ValueError) as caught:
            extractall(str(archive), str(destination))

        message = str(caught.exception).lower()
        self.assertIn("unsafe path", message)
199 | 198 |
|
def test_malicious_symlink_protection(self):
    """Expect a tar symlink that points outside the target to be rejected.

    The archive holds one regular file plus a symlink entry whose link
    target is ``../../../etc/passwd``. The test requires ``extractall`` to
    raise ``ValueError`` with a message that mentions either "unsafe path"
    or "symlink" (compared case-insensitively).
    """
    with tempfile.TemporaryDirectory() as workspace:
        root = Path(workspace)
        archive = root / "malicious_symlink_test.tar.gz"
        destination = root / "extract"
        destination.mkdir()

        regular = root / "normal.txt"
        regular.write_text("normal content")

        # Hand-built header: a symlink entry carries no data, only a target.
        escaping_link = tarfile.TarInfo(name="malicious_symlink.txt")
        escaping_link.type = tarfile.SYMTYPE
        escaping_link.linkname = "../../../etc/passwd"
        escaping_link.size = 0

        with tarfile.open(archive, 'w:gz') as bundle:
            bundle.add(regular, arcname="normal.txt")
            bundle.addfile(escaping_link)

        with self.assertRaises(ValueError) as caught:
            extractall(str(archive), str(destination))

        message = str(caught.exception).lower()
        accepted_fragments = ("unsafe path", "symlink")
        self.assertTrue(any(fragment in message for fragment in accepted_fragments))
225 | | - |
def test_malicious_hardlink_protection(self):
    """Expect a tar hard link that targets an absolute path to be rejected.

    The archive holds one regular file plus a hard-link entry whose link
    target is ``/etc/passwd``. The test requires ``extractall`` to raise
    ``ValueError`` with a message that mentions either "unsafe path" or
    "hardlink" (compared case-insensitively).
    """
    with tempfile.TemporaryDirectory() as workspace:
        root = Path(workspace)
        archive = root / "malicious_hardlink_test.tar.gz"
        destination = root / "extract"
        destination.mkdir()

        regular = root / "normal.txt"
        regular.write_text("normal content")

        # Hand-built header: a hard-link entry carries no data, only a target.
        escaping_link = tarfile.TarInfo(name="malicious_hardlink.txt")
        escaping_link.type = tarfile.LNKTYPE
        escaping_link.linkname = "/etc/passwd"
        escaping_link.size = 0

        with tarfile.open(archive, 'w:gz') as bundle:
            bundle.add(regular, arcname="normal.txt")
            bundle.addfile(escaping_link)

        with self.assertRaises(ValueError) as caught:
            extractall(str(archive), str(destination))

        message = str(caught.exception).lower()
        accepted_fragments = ("unsafe path", "hardlink")
        self.assertTrue(any(fragment in message for fragment in accepted_fragments))
251 | 250 |
|
|
0 commit comments