Skip to content

Commit

Permalink
Fix arbitrary file write during tarfile extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
Ali-Razmjoo committed Sep 4, 2024
1 parent 74e6e63 commit 5850c44
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 10 deletions.
8 changes: 3 additions & 5 deletions luigi/contrib/lsf_runner.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
except ImportError:
import pickle
import logging
import tarfile
from luigi.safe_extractor import SafeExtractor


def do_work_on_compute_node(work_dir):
Expand All @@ -52,10 +52,8 @@ def extract_packages_archive(work_dir):
curdir = os.path.abspath(os.curdir)

os.chdir(work_dir)
tar = tarfile.open(package_file)
for tarinfo in tar:
tar.extract(tarinfo)
tar.close()
extractor = SafeExtractor(work_dir)
extractor.safe_extract(package_file)
if '' not in sys.path:
sys.path.insert(0, '')

Expand Down
8 changes: 3 additions & 5 deletions luigi/contrib/sge_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import sys
import pickle
import logging
import tarfile
from luigi.safe_extractor import SafeExtractor


def _do_work_on_compute_node(work_dir, tarball=True):
Expand Down Expand Up @@ -64,10 +64,8 @@ def _extract_packages_archive(work_dir):
curdir = os.path.abspath(os.curdir)

os.chdir(work_dir)
tar = tarfile.open(package_file)
for tarinfo in tar:
tar.extract(tarinfo)
tar.close()
extractor = SafeExtractor(work_dir)
extractor.safe_extract(package_file)
if '' not in sys.path:
sys.path.insert(0, '')

Expand Down
96 changes: 96 additions & 0 deletions luigi/safe_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module provides a class `SafeExtractor` that offers a secure way to extract tar files while
mitigating path traversal vulnerabilities, which can occur when files inside the archive are
crafted to escape the intended extraction directory.
The `SafeExtractor` ensures that the extracted file paths are validated before extraction to
prevent malicious archives from extracting files outside the intended directory.
Classes:
SafeExtractor: A class to securely extract tar files with protection against path traversal attacks.
Usage Example:
extractor = SafeExtractor("/desired/directory")
extractor.safe_extract("archive.tar")
"""

import os
import tarfile


class SafeExtractor:
"""
A class to safely extract tar files, ensuring that no path traversal
vulnerabilities are exploited.
Attributes:
path (str): The directory to extract files into.
Methods:
_is_within_directory(directory, target):
Checks if a target path is within a given directory.
safe_extract(tar_path, members=None, *, numeric_owner=False):
Safely extracts the contents of a tar file to the specified directory.
"""

def __init__(self, path="."):
"""
Initializes the SafeExtractor with the specified directory path.
Args:
path (str): The directory to extract files into. Defaults to the current directory.
"""
self.path = path

def _is_within_directory(self, directory, target):
"""
Checks if a target path is within a given directory.
Args:
directory (str): The directory to check against.
target (str): The target path to check.
Returns:
bool: True if the target path is within the directory, False otherwise.
"""
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory

def safe_extract(self, tar_path, members=None, *, numeric_owner=False):
"""
Safely extracts the contents of a tar file to the specified directory.
Args:
tar_path (str): The path to the tar file to extract.
members (list, optional): A list of members to extract. Defaults to None.
numeric_owner (bool, optional): If True, only the numeric owner will be used. Defaults to False.
Raises:
ValueError: If a path traversal attempt is detected.
"""
with tarfile.open(tar_path, 'r') as tar:
for member in tar.getmembers():
member_path = os.path.join(self.path, member.name)
if not self._is_within_directory(self.path, member_path):
raise ValueError("Attempted Path Traversal in Tar File")
tar.extractall(self.path, members, numeric_owner=numeric_owner)
44 changes: 44 additions & 0 deletions test/contrib/lsf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@

import luigi
from luigi.contrib.lsf import LSFJobTask
import tarfile
import tempfile
import shutil

import pytest
from luigi.safe_extractor import SafeExtractor

DEFAULT_HOME = ''

Expand Down Expand Up @@ -103,5 +107,45 @@ def tearDown(self):
pass


class TestSafeExtract(unittest.TestCase):

def setUp(self):
self.temp_dir = tempfile.mkdtemp()

def tearDown(self):
shutil.rmtree(self.temp_dir)

def test_safe_extract(self):
tar_path = os.path.join(self.temp_dir, 'test.tar')
with tarfile.open(tar_path, 'w') as tar:
for i in range(3):
file_path = os.path.join(self.temp_dir, f'test_file_{i}.txt')
with open(file_path, 'w') as f:
f.write(f'This is test file {i}')
tar.add(file_path, arcname=f'test_file_{i}.txt')

extractor = SafeExtractor(self.temp_dir)
extractor.safe_extract(tar_path)

for i in range(3):
file_path = os.path.join(self.temp_dir, f'test_file_{i}.txt')
self.assertTrue(os.path.exists(file_path))
with open(file_path, 'r') as f:
content = f.read()
self.assertEqual(content, f'This is test file {i}')

def test_safe_extract_with_traversal(self):
tar_path = os.path.join(self.temp_dir, 'test.tar')
with tarfile.open(tar_path, 'w') as tar:
file_path = os.path.join(self.temp_dir, 'test_file.txt')
with open(file_path, 'w') as f:
f.write('This is a test file')
tar.add(file_path, arcname='../../test_file.txt')

extractor = SafeExtractor(self.temp_dir)
with self.assertRaises(ValueError):
extractor.safe_extract(tar_path)


if __name__ == '__main__':
unittest.main()

0 comments on commit 5850c44

Please sign in to comment.