Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions pystreamapi/loaders/__csv_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from collections import namedtuple
from csv import reader

from pystreamapi.loaders.__lazy_file_iterable import LazyFileIterable

def csv(file_path: str, delimiter=',', encoding="utf-8") -> list:

def csv(file_path: str, delimiter=',', encoding="utf-8") -> LazyFileIterable:
"""
Loads a CSV file and converts it into a list of namedtuples.

Expand All @@ -15,27 +17,29 @@ def csv(file_path: str, delimiter=',', encoding="utf-8") -> list:
:param delimiter: The delimiter used in the CSV file.
"""
file_path = __validate_path(file_path)
return LazyFileIterable(lambda: __load_csv(file_path, delimiter, encoding))


def __load_csv(file_path, delimiter, encoding):
    """
    Load a CSV file and convert it into a list of namedtuples.

    The first row of the file is used as the header and provides the field
    names of the ``Row`` namedtuple type. Every following non-blank row is
    turned into one ``Row``; each cell is passed through ``__try_cast``
    (defined elsewhere in this module; per the original comments it casts to
    int or float where possible).

    :param file_path: The path to the CSV file.
    :param delimiter: The delimiter used in the CSV file.
    :param encoding: The encoding used to open the CSV file.
    :return: A list with one ``Row`` namedtuple per non-blank data row. An
        empty file yields an empty list.
    """
    # skipcq: PTC-W6004
    with open(file_path, mode='r', newline='', encoding=encoding) as csvfile:
        csvreader = reader(csvfile, delimiter=delimiter)

        # Create a namedtuple type from the header row (header values are used
        # as-is, they are not cast). An empty file gives a field-less type.
        Row = namedtuple('Row', list(next(csvreader, [])))

        # csv.reader yields an empty list for blank lines; skip those so a
        # trailing or stray empty line does not fail Row construction.
        return [
            Row(*[__try_cast(value) for value in row])
            for row in csvreader
            if row
        ]


def __validate_path(file_path: str) -> str:
    """
    Validate the path to the CSV file.

    Relative and absolute paths are both accepted; the path only has to
    point at an existing regular file.

    :param file_path: The path to check.
    :return: The unchanged ``file_path`` when it is valid.
    :raises FileNotFoundError: If nothing exists at ``file_path``.
    :raises ValueError: If ``file_path`` exists but is not a file
        (for example a directory).
    """
    # Check existence first so a missing path is reported as
    # FileNotFoundError rather than as "not a file".
    if not os.path.exists(file_path):
        raise FileNotFoundError("The specified file does not exist.")

    if not os.path.isfile(file_path):
        raise ValueError("The specified path is not a file.")
    return file_path


Expand Down
23 changes: 23 additions & 0 deletions pystreamapi/loaders/__lazy_file_iterable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
class LazyFileIterable:
    """
    Iterable wrapper that defers loading its data until first use.

    The ``loader`` callable is not invoked on construction. It runs the first
    time the object is iterated, indexed or measured with ``len``, and its
    result is kept for all later accesses.
    """

    def __init__(self, loader):
        """
        :param loader: Zero-argument callable returning the data to wrap.
        """
        self.__loader = loader
        # None marks "not loaded yet".
        self.__data = None

    def __iter__(self):
        """Return an iterator over the loaded data, loading it if needed."""
        self.__load_data()
        return iter(self.__data)

    def __getitem__(self, index):
        """Return the item at ``index`` of the loaded data, loading it if needed."""
        self.__load_data()
        return self.__data[index]

    def __len__(self):
        """Return the length of the loaded data, loading it if needed."""
        self.__load_data()
        return len(self.__data)

    def __load_data(self):
        """Invoke the loader and store its result unless data is already present."""
        if self.__data is not None:
            return
        self.__data = self.__loader()
10 changes: 8 additions & 2 deletions tests/test_loaders.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
from unittest import TestCase

from pystreamapi.loaders import csv


class TestLoaders(TestCase):

def setUp(self) -> None:
Expand All @@ -18,6 +20,10 @@ def test_csv_loader(self):
self.assertEqual(data[1].attr1, 'a')
self.assertIsInstance(data[1].attr1, str)

def test_csv_loader_is_iterable(self):
    """Iterating the object returned by ``csv`` yields both data rows."""
    rows = list(iter(csv(f'{self.path}/data.csv')))
    self.assertEqual(len(rows), 2)

def test_csv_loader_with_custom_delimiter(self):
data = csv(f'{self.path}/data2.csv', delimiter=';')
self.assertEqual(len(data), 1)
Expand All @@ -32,6 +38,6 @@ def test_csv_loader_with_invalid_path(self):
with self.assertRaises(FileNotFoundError):
csv(f'{self.path}/invalid.csv')

def test_csv_loader_with_non_file(self):
    """A path that exists but is a directory must raise ``ValueError``."""
    with self.assertRaises(ValueError):
        csv(f'{self.path}/')