Skip to content

Commit

Permalink
Fix arbitrary file write during tarfile extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
Ali-Razmjoo committed Sep 6, 2024
1 parent c9a0d20 commit b012a00
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 10 deletions.
8 changes: 3 additions & 5 deletions luigi/contrib/lsf_runner.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
except ImportError:
import pickle
import logging
import tarfile
from luigi.safe_extractor import SafeExtractor


def do_work_on_compute_node(work_dir):
Expand All @@ -52,10 +52,8 @@ def extract_packages_archive(work_dir):
curdir = os.path.abspath(os.curdir)

os.chdir(work_dir)
tar = tarfile.open(package_file)
for tarinfo in tar:
tar.extract(tarinfo)
tar.close()
extractor = SafeExtractor(work_dir)
extractor.safe_extract(package_file)
if '' not in sys.path:
sys.path.insert(0, '')

Expand Down
8 changes: 3 additions & 5 deletions luigi/contrib/sge_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import sys
import pickle
import logging
import tarfile
from luigi.safe_extractor import SafeExtractor


def _do_work_on_compute_node(work_dir, tarball=True):
Expand Down Expand Up @@ -64,10 +64,8 @@ def _extract_packages_archive(work_dir):
curdir = os.path.abspath(os.curdir)

os.chdir(work_dir)
tar = tarfile.open(package_file)
for tarinfo in tar:
tar.extract(tarinfo)
tar.close()
extractor = SafeExtractor(work_dir)
extractor.safe_extract(package_file)
if '' not in sys.path:
sys.path.insert(0, '')

Expand Down
97 changes: 97 additions & 0 deletions luigi/safe_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module provides a class `SafeExtractor` that offers a secure way to extract tar files while
mitigating path traversal vulnerabilities, which can occur when files inside the archive are
crafted to escape the intended extraction directory.
The `SafeExtractor` ensures that the extracted file paths are validated before extraction to
prevent malicious archives from extracting files outside the intended directory.
Classes:
SafeExtractor: A class to securely extract tar files with protection against path traversal attacks.
Usage Example:
extractor = SafeExtractor("/desired/directory")
extractor.safe_extract("archive.tar")
"""

import os
import tarfile


class SafeExtractor:
"""
A class to safely extract tar files, ensuring that no path traversal
vulnerabilities are exploited.
Attributes:
path (str): The directory to extract files into.
Methods:
_is_within_directory(directory, target):
Checks if a target path is within a given directory.
safe_extract(tar_path, members=None, \\*, numeric_owner=False):
Safely extracts the contents of a tar file to the specified directory.
"""

def __init__(self, path="."):
"""
Initializes the SafeExtractor with the specified directory path.
Args:
path (str): The directory to extract files into. Defaults to the current directory.
"""
self.path = path

@staticmethod
def _is_within_directory(directory, target):
"""
Checks if a target path is within a given directory.
Args:
directory (str): The directory to check against.
target (str): The target path to check.
Returns:
bool: True if the target path is within the directory, False otherwise.
"""
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory

def safe_extract(self, tar_path, members=None, *, numeric_owner=False):
"""
Safely extracts the contents of a tar file to the specified directory.
Args:
tar_path (str): The path to the tar file to extract.
members (list, optional): A list of members to extract. Defaults to None.
numeric_owner (bool, optional): If True, only the numeric owner will be used. Defaults to False.
Raises:
RuntimeError: If a path traversal attempt is detected.
"""
with tarfile.open(tar_path, 'r') as tar:
for member in tar.getmembers():
member_path = os.path.join(self.path, member.name)
if not self._is_within_directory(self.path, member_path):
raise RuntimeError("Attempted Path Traversal in Tar File")
tar.extractall(self.path, members, numeric_owner=numeric_owner)
125 changes: 125 additions & 0 deletions test/safe_extractor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Safe Extractor Test
=============
Tests for the Safe Extractor class in luigi.safe_extractor module.
"""

import os
import shutil
import tarfile
import tempfile
import unittest

from luigi.safe_extractor import SafeExtractor


class TestSafeExtract(unittest.TestCase):
"""
Unit test class for testing the SafeExtractor module.
"""

def setUp(self):
"""Set up a temporary directory for test files."""
self.temp_dir = tempfile.mkdtemp()
self.test_file_template = 'test_file_{}.txt'
self.tar_file_name = 'test.tar'
self.tar_file_name_with_traversal = f'traversal_{self.tar_file_name}'

def tearDown(self):
"""Clean up the temporary directory after each test."""
shutil.rmtree(self.temp_dir)

def create_test_tar(self, tar_path, file_count=1, with_traversal=False):
"""
Create a tar file containing test files.
Args:
tar_path (str): Path where the tar file will be created.
file_count (int): Number of test files to include.
with_traversal (bool): If True, creates a tar file with path traversal vulnerability.
"""
# Default content for the test files
file_contents = [f'This is {self.test_file_template.format(i)}' for i in range(file_count)]

with tarfile.open(tar_path, 'w') as tar:
for i in range(file_count):
file_name = self.test_file_template.format(i)
file_path = os.path.join(self.temp_dir, file_name)

# Write content to each test file
with open(file_path, 'w') as f:
f.write(file_contents[i])

# If path traversal is enabled, create malicious paths
archive_name = f'../../{file_name}' if with_traversal else file_name

# Add the file to the tar archive
tar.add(file_path, arcname=archive_name)

def verify_extracted_files(self, file_count):
"""
Verify that the correct files were extracted and their contents match expectations.
Args:
file_count (int): Number of files to verify.
"""
for i in range(file_count):
file_name = self.test_file_template.format(i)
file_path = os.path.join(self.temp_dir, file_name)

# Check if the file exists
self.assertTrue(os.path.exists(file_path), f"File {file_name} does not exist.")

# Check if the file content is correct
with open(file_path, 'r') as f:
content = f.read()
expected_content = f'This is {file_name}'
self.assertEqual(content, expected_content, f"Content mismatch in {file_name}.")

def test_safe_extract(self):
"""Test normal safe extraction of tar files."""
tar_path = os.path.join(self.temp_dir, self.tar_file_name)

# Create a tar file with 3 files
self.create_test_tar(tar_path, file_count=3)

# Initialize SafeExtractor and perform extraction
extractor = SafeExtractor(self.temp_dir)
extractor.safe_extract(tar_path)

# Verify that all 3 files were extracted correctly
self.verify_extracted_files(3)

def test_safe_extract_with_traversal(self):
"""Test safe extraction for tar files with path traversal (should raise an error)."""
tar_path = os.path.join(self.temp_dir, self.tar_file_name_with_traversal)

# Create a tar file with a path traversal file
self.create_test_tar(tar_path, file_count=1, with_traversal=True)

# Initialize SafeExtractor and expect RuntimeError due to path traversal
extractor = SafeExtractor(self.temp_dir)
with self.assertRaises(RuntimeError):
extractor.safe_extract(tar_path)


if __name__ == '__main__':
unittest.main()

0 comments on commit b012a00

Please sign in to comment.