Skip to content

Commit

Permalink
Fix arbitrary file write during tarfile extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
Ali-Razmjoo committed Sep 6, 2024
1 parent c9a0d20 commit d7bccbb
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 10 deletions.
8 changes: 3 additions & 5 deletions luigi/contrib/lsf_runner.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
except ImportError:
import pickle
import logging
import tarfile
from luigi.safe_extractor import SafeExtractor


def do_work_on_compute_node(work_dir):
Expand All @@ -52,10 +52,8 @@ def extract_packages_archive(work_dir):
curdir = os.path.abspath(os.curdir)

os.chdir(work_dir)
tar = tarfile.open(package_file)
for tarinfo in tar:
tar.extract(tarinfo)
tar.close()
extractor = SafeExtractor(work_dir)
extractor.safe_extract(package_file)
if '' not in sys.path:
sys.path.insert(0, '')

Expand Down
8 changes: 3 additions & 5 deletions luigi/contrib/sge_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import sys
import pickle
import logging
import tarfile
from luigi.safe_extractor import SafeExtractor


def _do_work_on_compute_node(work_dir, tarball=True):
Expand Down Expand Up @@ -64,10 +64,8 @@ def _extract_packages_archive(work_dir):
curdir = os.path.abspath(os.curdir)

os.chdir(work_dir)
tar = tarfile.open(package_file)
for tarinfo in tar:
tar.extract(tarinfo)
tar.close()
extractor = SafeExtractor(work_dir)
extractor.safe_extract(package_file)
if '' not in sys.path:
sys.path.insert(0, '')

Expand Down
97 changes: 97 additions & 0 deletions luigi/safe_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module provides a class `SafeExtractor` that offers a secure way to extract tar files while
mitigating path traversal vulnerabilities, which can occur when files inside the archive are
crafted to escape the intended extraction directory.
The `SafeExtractor` ensures that the extracted file paths are validated before extraction to
prevent malicious archives from extracting files outside the intended directory.
Classes:
SafeExtractor: A class to securely extract tar files with protection against path traversal attacks.
Usage Example:
extractor = SafeExtractor("/desired/directory")
extractor.safe_extract("archive.tar")
"""

import os
import tarfile


class SafeExtractor:
"""
A class to safely extract tar files, ensuring that no path traversal
vulnerabilities are exploited.
Attributes:
path (str): The directory to extract files into.
Methods:
_is_within_directory(directory, target):
Checks if a target path is within a given directory.
safe_extract(tar_path, members=None, \\*, numeric_owner=False):
Safely extracts the contents of a tar file to the specified directory.
"""

def __init__(self, path="."):
"""
Initializes the SafeExtractor with the specified directory path.
Args:
path (str): The directory to extract files into. Defaults to the current directory.
"""
self.path = path

@staticmethod
def _is_within_directory(directory, target):
"""
Checks if a target path is within a given directory.
Args:
directory (str): The directory to check against.
target (str): The target path to check.
Returns:
bool: True if the target path is within the directory, False otherwise.
"""
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory

def safe_extract(self, tar_path, members=None, *, numeric_owner=False):
"""
Safely extracts the contents of a tar file to the specified directory.
Args:
tar_path (str): The path to the tar file to extract.
members (list, optional): A list of members to extract. Defaults to None.
numeric_owner (bool, optional): If True, only the numeric owner will be used. Defaults to False.
Raises:
ValueError: If a path traversal attempt is detected.
"""
with tarfile.open(tar_path, 'r') as tar:
for member in tar.getmembers():
member_path = os.path.join(self.path, member.name)
if not self._is_within_directory(self.path, member_path):
raise ValueError("Attempted Path Traversal in Tar File")
tar.extractall(self.path, members, numeric_owner=numeric_owner)
116 changes: 116 additions & 0 deletions test/safe_extractor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Safe Extractor Test
=============
Tests for the Safe Extractor class in luigi.safe_extractor module.
"""

import os
import shutil
import tarfile
import tempfile
import unittest

from luigi.safe_extractor import SafeExtractor


class TestSafeExtract(unittest.TestCase):
"""
Unit test class for testing the SafeExtractor module.
"""

def setUp(self):
"""Set up temporary directory for test files."""
self.temp_dir = tempfile.mkdtemp()
self.test_file_template = 'test_file_{}.txt'
self.tar_file_name = 'test.tar'

def tearDown(self):
"""Clean up the temporary directory after each test."""
shutil.rmtree(self.temp_dir)

def create_tar(self, tar_path, file_count=1, file_contents=None, archive_name=None, with_traversal=False):
"""
Helper method to create a tar file with test data.
Args:
tar_path (str): Path to the tar file to be created.
file_count (int): Number of test files to include in the tar.
file_contents (list): Optional list of file contents.
archive_name (str): Optional name for files added to tar.
with_traversal (bool): If True, creates a path traversal tarball.
"""
file_contents = file_contents or [f'This is {self.test_file_template.format(i)}' for i in range(file_count)]
with tarfile.open(tar_path, 'w') as tar:
for i in range(file_count):
file_name = self.test_file_template.format(i)
file_path = os.path.join(self.temp_dir, file_name)
with open(file_path, 'w') as f:
f.write(file_contents[i])

# Handle path traversal if requested
archive_name = archive_name or (f'../../{file_name}' if with_traversal else file_name)
tar.add(file_path, arcname=archive_name)

def verify_extracted_files(self, file_count):
"""
Helper method to verify files extracted from tar.
Args:
file_count (int): Number of files to verify.
"""
for i in range(file_count):
file_name = self.test_file_template.format(i)
file_path = os.path.join(self.temp_dir, file_name)
self.assertTrue(os.path.exists(file_path), f"File {file_name} does not exist.")
with open(file_path, 'r') as f:
content = f.read()
self.assertEqual(content, f'This is {file_name}', f"Content mismatch in {file_name}.")

def test_safe_extract(self):
"""Test normal safe extraction of tar files."""
self.run_safe_extract_test()

def test_safe_extract_with_traversal(self):
"""Test safe extraction for tar files with path traversal."""
self.run_safe_extract_test(with_traversal=True, expect_error=True)

def run_safe_extract_test(self, with_traversal=False, expect_error=False):
"""
Run a safe extract test with optional path traversal.
Args:
with_traversal (bool): If True, test for path traversal.
expect_error (bool): If True, expect an error during extraction.
"""
tar_path = os.path.join(self.temp_dir, self.tar_file_name)
self.create_tar(tar_path, file_count=3 if not with_traversal else 1, with_traversal=with_traversal)

extractor = SafeExtractor(self.temp_dir)
if expect_error:
with self.assertRaises(ValueError):
extractor.safe_extract(tar_path)
else:
extractor.safe_extract(tar_path)
self.verify_extracted_files(3)


if __name__ == '__main__':
unittest.main()

0 comments on commit d7bccbb

Please sign in to comment.