Allow polars as valid output type #6762

Merged: 4 commits, Aug 16, 2024
Changes from 1 commit
add polars as a return type in tests
psmyth94 committed Aug 15, 2024
commit 858d5b0263bff59ca007d90c7b8fe19c775a47f8
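
For context before the diff: the tests added in this commit exercise returning a polars.DataFrame from Dataset.map, in both non-batched and batched mode. Below is a minimal usage sketch (hypothetical dataset and column names, not taken from this PR), assuming a datasets build that includes this change and polars installed:

    import polars as pl
    from datasets import Dataset

    ds = Dataset.from_dict({"text": ["foo", "bar", "baz"]})

    # Non-batched mode: the map function receives one example and must return
    # a single-row DataFrame (returning more than one row raises ValueError,
    # which the new test also covers).
    upper_ds = ds.map(lambda example: pl.DataFrame({"text": [example["text"].upper()]}))

    # Batched mode: the map function receives a dict of lists and returns one
    # row per example in the batch.
    def upper_batch(batch):
        return pl.DataFrame({"text": [t.upper() for t in batch["text"]]})

    upper_batched_ds = ds.map(upper_batch, batched=True)

As the new test asserts, polars string columns surface as large_string features in the resulting dataset.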
100 changes: 87 additions & 13 deletions tests/test_arrow_dataset.py
@@ -915,7 +915,10 @@ def test_flatten(self, in_memory):
with Dataset.from_dict(
{"a": [{"en": "the cat", "fr": ["le chat", "la chatte"], "de": "die katze"}] * 10, "foo": [1] * 10},
features=Features(
{"a": TranslationVariableLanguages(languages=["en", "fr", "de"]), "foo": Value("int64")}
{
"a": TranslationVariableLanguages(languages=["en", "fr", "de"]),
"foo": Value("int64"),
}
),
) as dset:
with self._to(in_memory, tmp_dir, dset) as dset:
@@ -1000,7 +1003,11 @@ def test_flatten_complex_image(self, in_memory):
self.assertDictEqual(
dset.features,
Features(
{"a.b.bytes": Value("binary"), "a.b.path": Value("string"), "foo": Value("int64")}
{
"a.b.bytes": Value("binary"),
"a.b.path": Value("string"),
"foo": Value("int64"),
}
),
)
self.assertNotEqual(dset._fingerprint, fingerprint)
@@ -1528,6 +1535,47 @@ def func_return_multi_row_pd_dataframe(x):
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
self.assertRaises(ValueError, dset.map, func_return_multi_row_pd_dataframe)

def test_map_return_pl_dataframe(self, in_memory):
import polars as pl

def func_return_single_row_pl_dataframe(x):
return pl.DataFrame({"id": [0], "text": ["a"]})

with tempfile.TemporaryDirectory() as tmp_dir:
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
with dset.map(func_return_single_row_pl_dataframe) as dset_test:
self.assertEqual(len(dset_test), 30)
self.assertDictEqual(
dset_test.features,
Features({"id": Value("int64"), "text": Value("large_string")}),
)
self.assertEqual(dset_test[0]["id"], 0)
self.assertEqual(dset_test[0]["text"], "a")

# Batched
def func_return_single_row_pl_dataframe_batched(x):
batch_size = len(x[next(iter(x))])
return pl.DataFrame({"id": [0] * batch_size, "text": ["a"] * batch_size})

with tempfile.TemporaryDirectory() as tmp_dir:
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
with dset.map(func_return_single_row_pl_dataframe_batched, batched=True) as dset_test:
self.assertEqual(len(dset_test), 30)
self.assertDictEqual(
dset_test.features,
Features({"id": Value("int64"), "text": Value("large_string")}),
)
self.assertEqual(dset_test[0]["id"], 0)
self.assertEqual(dset_test[0]["text"], "a")

# Error when returning a table with more than one row in the non-batched mode
def func_return_multi_row_pl_dataframe(x):
return pl.DataFrame({"id": [0, 1], "text": ["a", "b"]})

with tempfile.TemporaryDirectory() as tmp_dir:
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
self.assertRaises(ValueError, dset.map, func_return_multi_row_pl_dataframe)

@require_numpy1_on_windows
@require_torch
def test_map_torch(self, in_memory):
@@ -1831,7 +1879,10 @@ def test_filter_caching(self, in_memory):

def test_keep_features_after_transform_specified(self, in_memory):
features = Features(
{"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
{
"tokens": Sequence(Value("string")),
"labels": Sequence(ClassLabel(names=["negative", "positive"])),
}
)

def invert_labels(x):
@@ -1849,7 +1900,10 @@ def invert_labels(x):

def test_keep_features_after_transform_unspecified(self, in_memory):
features = Features(
{"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
{
"tokens": Sequence(Value("string")),
"labels": Sequence(ClassLabel(names=["negative", "positive"])),
}
)

def invert_labels(x):
@@ -1867,7 +1921,10 @@ def invert_labels(x):

def test_keep_features_after_transform_to_file(self, in_memory):
features = Features(
{"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
{
"tokens": Sequence(Value("string")),
"labels": Sequence(ClassLabel(names=["negative", "positive"])),
}
)

def invert_labels(x):
@@ -1886,7 +1943,10 @@ def invert_labels(x):

def test_keep_features_after_transform_to_memory(self, in_memory):
features = Features(
{"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
{
"tokens": Sequence(Value("string")),
"labels": Sequence(ClassLabel(names=["negative", "positive"])),
}
)

def invert_labels(x):
@@ -1903,7 +1963,10 @@ def invert_labels(x):

def test_keep_features_after_loading_from_cache(self, in_memory):
features = Features(
{"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
{
"tokens": Sequence(Value("string")),
"labels": Sequence(ClassLabel(names=["negative", "positive"])),
}
)

def invert_labels(x):
@@ -1926,7 +1989,10 @@ def invert_labels(x):

def test_keep_features_with_new_features(self, in_memory):
features = Features(
{"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
{
"tokens": Sequence(Value("string")),
"labels": Sequence(ClassLabel(names=["negative", "positive"])),
}
)

def invert_labels(x):
@@ -3710,7 +3776,11 @@ def test_dataset_from_json_features(features, jsonl_path, tmp_path):

def test_dataset_from_json_with_class_label_feature(jsonl_str_path, tmp_path):
features = Features(
{"col_1": ClassLabel(names=["s0", "s1", "s2", "s3"]), "col_2": Value("int64"), "col_3": Value("float64")}
{
"col_1": ClassLabel(names=["s0", "s1", "s2", "s3"]),
"col_2": Value("int64"),
"col_3": Value("float64"),
}
)
cache_dir = tmp_path / "cache"
dataset = Dataset.from_json(jsonl_str_path, features=features, cache_dir=cache_dir)
@@ -4262,7 +4332,11 @@ def test_task_question_answering(self):
def test_task_summarization(self):
# Include a dummy extra column `dummy` to test we drop it correctly
features_before_cast = Features(
{"input_text": Value("string"), "input_summary": Value("string"), "dummy": Value("string")}
{
"input_text": Value("string"),
"input_summary": Value("string"),
"dummy": Value("string"),
}
)
features_after_cast = Features({"text": Value("string"), "summary": Value("string")})
task = Summarization(text_column="input_text", summary_column="input_summary")
@@ -4882,7 +4956,7 @@ def test_dataset_batch():
assert len(batch["id"]) == 3
assert len(batch["text"]) == 3
assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]
assert batch["text"] == [f"Text {3 * i}", f"Text {3 * i + 1}", f"Text {3 * i + 2}"]

# Check last partial batch
assert len(batches[3]["id"]) == 1
@@ -4899,7 +4973,7 @@ def test_dataset_batch():
assert len(batch["id"]) == 3
assert len(batch["text"]) == 3
assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]
assert batch["text"] == [f"Text {3 * i}", f"Text {3 * i + 1}", f"Text {3 * i + 2}"]

# Test with batch_size=4 (doesn't evenly divide dataset size)
batched_ds = ds.batch(batch_size=4, drop_last_batch=False)
@@ -4910,7 +4984,7 @@ def test_dataset_batch():
assert len(batch["id"]) == 4
assert len(batch["text"]) == 4
assert batch["id"] == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3]
assert batch["text"] == [f"Text {4*i}", f"Text {4*i+1}", f"Text {4*i+2}", f"Text {4*i+3}"]
assert batch["text"] == [f"Text {4 * i}", f"Text {4 * i + 1}", f"Text {4 * i + 2}", f"Text {4 * i + 3}"]

# Check last partial batch
assert len(batches[2]["id"]) == 2