REFACTOR-#2739: io tests refactoring #2740

Merged
merged 2 commits into from Feb 16, 2021
181 changes: 181 additions & 0 deletions modin/conftest.py
@@ -14,6 +14,11 @@
import os
import sys
import pytest
import pandas
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import shutil

import modin
import modin.config
@@ -23,6 +28,17 @@
from modin.engines.python.pandas_on_python.io import PandasOnPythonIO
from modin.data_management.factories import factories
from modin.utils import get_current_backend
from modin.pandas.test.utils import (
    _make_csv_file,
    get_unique_filename,
    teardown_test_files,
    NROWS,
    IO_OPS_DATA_DIR,
)

# create the test data dir if it does not exist yet
if not os.path.exists(IO_OPS_DATA_DIR):
    os.mkdir(IO_OPS_DATA_DIR)


def pytest_addoption(parser):
@@ -232,3 +248,168 @@ def pytest_runtest_call(item):
                **marker.kwargs,
            )
        )


@pytest.fixture(scope="class")
def TestReadCSVFixture():
    filenames = []
    files_ids = [
        "test_read_csv_regular",
        "test_read_csv_blank_lines",
        "test_read_csv_yes_no",
        "test_read_csv_nans",
        "test_read_csv_bad_lines",
    ]
    # each xdist worker is spawned in a separate process with its own namespace and dataset
    pytest.csvs_names = {file_id: get_unique_filename() for file_id in files_ids}
    # test_read_csv_col_handling, test_read_csv_parsing
    _make_csv_file(filenames)(
        filename=pytest.csvs_names["test_read_csv_regular"],
    )
    # test_read_csv_parsing
    _make_csv_file(filenames)(
        filename=pytest.csvs_names["test_read_csv_yes_no"],
        additional_col_values=["Yes", "true", "No", "false"],
    )
    # test_read_csv_col_handling
    _make_csv_file(filenames)(
        filename=pytest.csvs_names["test_read_csv_blank_lines"],
        add_blank_lines=True,
    )
    # test_read_csv_nans_handling
    _make_csv_file(filenames)(
        filename=pytest.csvs_names["test_read_csv_nans"],
        add_blank_lines=True,
        additional_col_values=["<NA>", "N/A", "NA", "NULL", "custom_nan", "73"],
    )
    # test_read_csv_error_handling
    _make_csv_file(filenames)(
        filename=pytest.csvs_names["test_read_csv_bad_lines"],
        add_bad_lines=True,
    )

    yield
    # Delete csv files that were created
    teardown_test_files(filenames)
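
A minimal usage sketch of this class-scoped fixture (the test class and method names here are hypothetical; df_equals is the comparison helper from modin.pandas.test.utils, imported elsewhere in this diff):

import modin.pandas as pd
import pandas
import pytest

from modin.pandas.test.utils import df_equals


@pytest.mark.usefixtures("TestReadCSVFixture")
class TestCsv:
    def test_read_csv_regular(self):
        # The fixture publishes the generated file paths via pytest.csvs_names.
        path = pytest.csvs_names["test_read_csv_regular"]
        df_equals(pd.read_csv(path), pandas.read_csv(path))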


@pytest.fixture
def make_csv_file():
    """Pytest fixture factory that makes temp csv files for testing.

    Yields:
        Function that generates csv files
    """
    filenames = []

    yield _make_csv_file(filenames)

    # Delete csv files that were created
    teardown_test_files(filenames)
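
A short usage sketch, with a hypothetical test function and the same imports as the sketch above (get_unique_filename is already imported in this file):

def test_read_csv_custom_nans(make_csv_file):
    # Calling the yielded factory writes a csv and registers it for cleanup.
    path = get_unique_filename()
    make_csv_file(filename=path, additional_col_values=["N/A", "NULL"])
    df_equals(pd.read_csv(path), pandas.read_csv(path))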


@pytest.fixture
def make_parquet_file():
    """Pytest fixture factory that makes a parquet file/dir for testing.

    Yields:
        Function that generates a parquet file/dir
    """
    filenames = []

    def _make_parquet_file(
        filename,
        row_size=NROWS,
        force=True,
        directory=False,
        partitioned_columns=[],
    ):
        """Helper function to generate parquet files/directories.

        Args:
            filename: The name of the test file that should be created.
            row_size: Number of rows for the dataframe.
            force: Create a new file/directory even if one already exists.
            directory: Create a partitioned directory using pyarrow.
            partitioned_columns: Create a partitioned directory using pandas.
                Will be ignored if directory=True.
        """
        df = pandas.DataFrame(
            {"col1": np.arange(row_size), "col2": np.arange(row_size)}
        )
        if os.path.exists(filename) and not force:
            pass
        elif directory:
            if os.path.exists(filename):
                shutil.rmtree(filename)
            else:
                os.mkdir(filename)
            table = pa.Table.from_pandas(df)
            pq.write_to_dataset(table, root_path=filename)
        elif len(partitioned_columns) > 0:
            df.to_parquet(filename, partition_cols=partitioned_columns)
        else:
            df.to_parquet(filename)

        filenames.append(filename)

    # Yield function that generates parquet files/directories
    yield _make_parquet_file

    # Delete parquet files/directories that were created
    for path in filenames:
        if os.path.exists(path):
            if os.path.isdir(path):
                shutil.rmtree(path)
            else:
                os.remove(path)
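
A usage sketch with a hypothetical test function; passing extension="parquet" to get_unique_filename mirrors the extension="" call used by TestReadGlobCSVFixture below and is an assumption here:

def test_read_parquet(make_parquet_file):
    path = get_unique_filename(extension="parquet")  # extension value is an assumption
    make_parquet_file(filename=path)
    df_equals(pd.read_parquet(path), pandas.read_parquet(path))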


@pytest.fixture
def make_sql_connection():
    """Sets up sql connections and tears them down after the caller is done.

    Yields:
        Factory that generates sql connection objects
    """
    filenames = []

    def _sql_connection(filename, table=""):
        # Remove the file if it already exists
        if os.path.exists(filename):
            os.remove(filename)
        filenames.append(filename)
        # Create the connection and, if needed, the table
        conn = "sqlite:///{}".format(filename)
        if table:
            df = pandas.DataFrame(
                {
                    "col1": [0, 1, 2, 3, 4, 5, 6],
                    "col2": [7, 8, 9, 10, 11, 12, 13],
                    "col3": [14, 15, 16, 17, 18, 19, 20],
                    "col4": [21, 22, 23, 24, 25, 26, 27],
                    "col5": [0, 0, 0, 0, 0, 0, 0],
                }
            )
            df.to_sql(table, conn)
        return conn

    yield _sql_connection

    # Teardown the fixture
    teardown_test_files(filenames)
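
A usage sketch with a hypothetical test function; the factory returns a sqlite URI string, which pandas.read_sql can consume via SQLAlchemy:

def test_read_sql(make_sql_connection):
    filename = get_unique_filename(extension="db")  # extension value is an assumption
    table = "test_table"
    conn = make_sql_connection(filename, table)
    query = "select * from {}".format(table)
    df_equals(pd.read_sql(query, conn), pandas.read_sql(query, conn))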


@pytest.fixture(scope="class")
def TestReadGlobCSVFixture():
    filenames = []

    base_name = get_unique_filename(extension="")
    pytest.glob_path = "{}_*.csv".format(base_name)
    pytest.files = ["{}_{}.csv".format(base_name, i) for i in range(11)]
    for fname in pytest.files:
        # Glob does not guarantee ordering so we have to remove the randomness in the generated csvs.
        _make_csv_file(filenames)(fname, row_size=11, remove_randomness=True)

    yield

    teardown_test_files(filenames)
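
This fixture is consumed by the glob-path test in the next file; a condensed sketch of such a test (assuming the experimental engine's read_csv accepts glob paths, which is what the Ray-only skipif below guards; since remove_randomness=True makes every generated file identical, glob ordering does not affect the comparison):

import modin.experimental.pandas as pd
import pandas
import pytest

from modin.pandas.test.utils import df_equals


@pytest.mark.usefixtures("TestReadGlobCSVFixture")
def test_read_multiple_csv_glob():
    # pytest.files and pytest.glob_path were set by the fixture above.
    pandas_df = pandas.concat(
        [pandas.read_csv(fname) for fname in pytest.files]
    ).reset_index(drop=True)
    modin_df = pd.read_csv(pytest.glob_path)
    df_equals(modin_df, pandas_df)
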
25 changes: 1 addition & 24 deletions modin/experimental/pandas/test/test_io_exp.py
@@ -15,14 +15,7 @@
import pytest
import modin.experimental.pandas as pd
from modin.config import Engine
from modin.pandas.test.test_io import (  # noqa: F401
    df_equals,
    eval_io,
    make_sql_connection,
    _make_csv_file,
    teardown_test_files,
)
from modin.pandas.test.utils import get_unique_filename
from modin.pandas.test.utils import df_equals


@pytest.mark.skipif(
@@ -69,22 +62,6 @@ def test_from_sql_defaults(make_sql_connection):  # noqa: F811
    df_equals(modin_df_from_table, pandas_df)


@pytest.fixture(scope="class")
def TestReadGlobCSVFixture():
    filenames = []

    base_name = get_unique_filename(extension="")
    pytest.glob_path = "{}_*.csv".format(base_name)
    pytest.files = ["{}_{}.csv".format(base_name, i) for i in range(11)]
    for fname in pytest.files:
        # Glob does not guarantee ordering so we have to remove the randomness in the generated csvs.
        _make_csv_file(filenames)(fname, row_size=11, remove_randomness=True)

    yield

    teardown_test_files(filenames)


@pytest.mark.usefixtures("TestReadGlobCSVFixture")
@pytest.mark.skipif(
Engine.get() != "Ray", reason="Currently only support Ray engine for glob paths."