Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from gokart.required_task_output import RequiredTaskOutput
from gokart.target import TargetOnKart
from gokart.task_complete_check import task_complete_check_wrapper
from gokart.utils import FlattenableItems, flatten, map_flattenable_items
from gokart.utils import FlattenableItems, flatten, get_dataframe_type_from_task, map_flattenable_items

logger = getLogger(__name__)

Expand Down Expand Up @@ -219,6 +219,10 @@ def make_target(self, relative_file_path: str | None = None, use_unique_id: bool
file_path = os.path.join(self.workspace_directory, formatted_relative_file_path)
unique_id = self.make_unique_id() if use_unique_id else None

# Auto-select processor based on type parameter if not provided
if processor is None and relative_file_path is not None:
processor = self._create_processor_for_dataframe_type(file_path)

task_lock_params = make_task_lock_params(
file_path=file_path,
unique_id=unique_id,
Expand All @@ -232,6 +236,38 @@ def make_target(self, relative_file_path: str | None = None, use_unique_id: bool
file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather
)

def _create_processor_for_dataframe_type(self, file_path: str) -> FileProcessor | None:
    """
    Build a DataFrame-aware file processor chosen by the file extension.

    The extension of ``file_path`` selects the processor class, and the
    DataFrame flavour reported by ``get_dataframe_type_from_task`` for this
    task is forwarded to that processor as ``dataframe_type``.

    Args:
        file_path: Path of the target file; only its extension is inspected.

    Returns:
        A processor for ``.csv``, ``.tsv``, ``.json``, ``.ndjson``,
        ``.parquet`` and ``.feather`` files, or ``None`` for every other
        extension so that the default processor selection is used instead.
    """
    from gokart.file_processor import CsvFileProcessor, FeatherFileProcessor, JsonFileProcessor, ParquetFileProcessor

    suffix = os.path.splitext(file_path)[1]
    frame_kind = get_dataframe_type_from_task(self)

    # Factories are wrapped in lambdas so that only the processor matching
    # the extension is ever constructed.
    factories = {
        '.csv': lambda: CsvFileProcessor(sep=',', dataframe_type=frame_kind),
        '.tsv': lambda: CsvFileProcessor(sep='\t', dataframe_type=frame_kind),
        '.json': lambda: JsonFileProcessor(orient=None, dataframe_type=frame_kind),
        '.ndjson': lambda: JsonFileProcessor(orient='records', dataframe_type=frame_kind),
        '.parquet': lambda: ParquetFileProcessor(dataframe_type=frame_kind),
        '.feather': lambda: FeatherFileProcessor(store_index_in_feather=self.store_index_in_feather, dataframe_type=frame_kind),
    }

    factory = factories.get(suffix)
    if factory is None:
        # Not a DataFrame-supporting extension: defer to default selection.
        return None
    return factory()

def make_large_data_frame_target(self, relative_file_path: str | None = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart:
formatted_relative_file_path = (
relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.zip')
Expand Down
47 changes: 46 additions & 1 deletion gokart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from collections.abc import Callable, Iterable
from io import BytesIO
from typing import Any, Protocol, TypeAlias, TypeVar
from typing import Any, Literal, Protocol, TypeAlias, TypeVar, get_args, get_origin

import dill
import luigi
Expand Down Expand Up @@ -92,3 +92,48 @@ def load_dill_with_pandas_backward_compatibility(file: FileLike | BytesIO) -> An
assert file.seekable(), f'{file} is not seekable.'
file.seek(0)
return pd.read_pickle(file)


def get_dataframe_type_from_task(task: Any) -> Literal['pandas', 'polars']:
    """
    Extract DataFrame type from TaskOnKart[T] type parameter.

    Walks the method resolution order of the task's class and inspects the
    generic bases each class declared itself, looking for a
    ``TaskOnKart[...]`` parameterization whose first type argument lives in
    a pandas or polars module. The first such match (nearest class first)
    decides the result.

    Args:
        task: A TaskOnKart instance or class

    Returns:
        'pandas' or 'polars' (defaults to 'pandas' if type cannot be determined)

    Examples:
        >>> class MyTask(TaskOnKart[pd.DataFrame]): pass
        >>> get_dataframe_type_from_task(MyTask())
        'pandas'

        >>> class MyPolarsTask(TaskOnKart[pl.DataFrame]): pass
        >>> get_dataframe_type_from_task(MyPolarsTask())
        'polars'
    """
    task_class = task if isinstance(task, type) else task.__class__

    for klass in task_class.__mro__:
        # Read only the bases declared on this very class. Plain attribute
        # lookup would return the nearest definition in the MRO, so a
        # subclass that adds another generic base would hide the
        # TaskOnKart[...] parameterization declared by one of its ancestors.
        for base in vars(klass).get('__orig_bases__', ()):
            origin = get_origin(base)
            # Matched by name because TaskOnKart is not in scope in this block.
            if getattr(origin, '__name__', None) != 'TaskOnKart':
                continue

            args = get_args(base)
            if not args:
                continue

            # ``__module__`` may be missing or None on unusual type arguments.
            module = getattr(args[0], '__module__', '') or ''

            # Check module name to determine DataFrame type
            if 'polars' in module:
                return 'polars'
            if 'pandas' in module:
                return 'pandas'

    return 'pandas'  # Default to pandas for backward compatibility
110 changes: 109 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
import unittest

from gokart.utils import flatten, map_flattenable_items
import pandas as pd
import pytest

from gokart.task import TaskOnKart
from gokart.utils import flatten, get_dataframe_type_from_task, map_flattenable_items

try:
import polars as pl

HAS_POLARS = True
except ImportError:
HAS_POLARS = False


class TestFlatten(unittest.TestCase):
Expand Down Expand Up @@ -34,3 +45,100 @@ def test_map_flattenable_items(self):
),
{'a': ['1', '2', '3', '4'], 'b': {'c': 'True', 'd': {'e': '5'}}},
)


class TestGetDataFrameTypeFromTask(unittest.TestCase):
    """Tests for get_dataframe_type_from_task function."""

    def test_pandas_dataframe_from_instance(self):
        """Test detecting pandas DataFrame from task instance."""

        class PandasTask(TaskOnKart[pd.DataFrame]):
            pass

        task = PandasTask()
        self.assertEqual(get_dataframe_type_from_task(task), 'pandas')

    def test_pandas_dataframe_from_class(self):
        """Test detecting pandas DataFrame from task class."""

        class PandasTask(TaskOnKart[pd.DataFrame]):
            pass

        self.assertEqual(get_dataframe_type_from_task(PandasTask), 'pandas')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_polars_dataframe_from_instance(self):
        """Test detecting polars DataFrame from task instance."""

        class PolarsTask(TaskOnKart[pl.DataFrame]):
            pass

        task = PolarsTask()
        self.assertEqual(get_dataframe_type_from_task(task), 'polars')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_polars_dataframe_from_class(self):
        """Test detecting polars DataFrame from task class."""

        class PolarsTask(TaskOnKart[pl.DataFrame]):
            pass

        self.assertEqual(get_dataframe_type_from_task(PolarsTask), 'polars')

    def test_no_type_parameter_defaults_to_pandas(self):
        """Test that a TaskOnKart subclass without a type parameter defaults to pandas."""

        # Subclass TaskOnKart without parameterizing it, so no
        # TaskOnKart[...] entry exists anywhere in the class hierarchy.
        class PlainTask(TaskOnKart):
            pass

        task = PlainTask()
        self.assertEqual(get_dataframe_type_from_task(task), 'pandas')

    def test_non_taskonkart_class_defaults_to_pandas(self):
        """Test that non-TaskOnKart classes (no __orig_bases__ at all) default to pandas."""

        class RegularClass:
            pass

        task = RegularClass()
        self.assertEqual(get_dataframe_type_from_task(task), 'pandas')

    def test_taskonkart_with_non_dataframe_type(self):
        """Test TaskOnKart with non-DataFrame type parameter defaults to pandas."""

        class StringTask(TaskOnKart[str]):
            pass

        task = StringTask()
        # Should default to pandas since str module is not 'pandas' or 'polars'
        self.assertEqual(get_dataframe_type_from_task(task), 'pandas')

    def test_nested_inheritance_pandas(self):
        """Test detecting pandas DataFrame type through nested inheritance."""

        class BasePandasTask(TaskOnKart[pd.DataFrame]):
            pass

        class DerivedPandasTask(BasePandasTask):
            pass

        task = DerivedPandasTask()
        # DerivedPandasTask declares no generic bases of its own; the type
        # parameter is taken from BasePandasTask in the inheritance chain.
        self.assertEqual(get_dataframe_type_from_task(task), 'pandas')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_nested_inheritance_polars(self):
        """Test detecting polars DataFrame type through nested inheritance."""

        class BasePolarsTask(TaskOnKart[pl.DataFrame]):
            pass

        class DerivedPolarsTask(BasePolarsTask):
            pass

        task = DerivedPolarsTask()
        # Function should detect 'polars' through the inheritance chain
        self.assertEqual(get_dataframe_type_from_task(task), 'polars')
Loading