56 changes: 53 additions & 3 deletions src/datasets/iterable_dataset.py
@@ -564,7 +564,25 @@ def _iter_arrow(self) -> Iterator[tuple[Key, pa.Table]]:
                         self._state_dict["batch_idx"] += 1
                         self._state_dict["num_chunks_since_previous_state"] += len(chunks_buffer)
                         self._state_dict["cropped_chunk_length"] = 0
-                    yield new_key, pa.Table.from_batches(chunks_buffer)
+
+                    if self.features:
+                        expected_schema = pa.schema(self.features.type)
+                        casted_chunks = []
+                        for chunk in chunks_buffer:
+                            try:
+                                casted_chunks.append(chunk.cast(expected_schema))
+                            except (pa.ArrowInvalid, pa.ArrowNotImplementedError):
+                                casted_chunks.append(chunk)
+                        yield new_key, pa.Table.from_batches(casted_chunks)
+                    else:
+                        # Unify schemas when no explicit features provided
+                        if chunks_buffer:
+                            unified_schema = pa.unify_schemas([chunk.schema for chunk in chunks_buffer])
+                            casted_chunks = [chunk.cast(unified_schema) for chunk in chunks_buffer]
+                            yield new_key, pa.Table.from_batches(casted_chunks, schema=unified_schema)
+                        else:
+                            yield new_key, pa.Table.from_batches(chunks_buffer)
+
                     keys_buffer = []
                     chunks_buffer = []
                     chunks_buffer_size = 0
@@ -580,7 +598,23 @@ def _iter_arrow(self) -> Iterator[tuple[Key, pa.Table]]:
                         self._state_dict["batch_idx"] += 1
                         self._state_dict["num_chunks_since_previous_state"] += len(chunks_buffer)
                         self._state_dict["cropped_chunk_length"] = cropped_chunk_length
-                    yield new_key, pa.Table.from_batches(chunks_buffer)
+
+                    if self.features:
+                        expected_schema = pa.schema(self.features.type)
+                        casted_chunks = []
+                        for buffered_chunk in chunks_buffer:
+                            try:
+                                casted_chunks.append(buffered_chunk.cast(expected_schema))
+                            except (pa.ArrowInvalid, pa.ArrowNotImplementedError):
+                                casted_chunks.append(buffered_chunk)
+                        yield new_key, pa.Table.from_batches(casted_chunks)
+                    else:
+                        if chunks_buffer:
+                            unified_schema = pa.unify_schemas([chunk.schema for chunk in chunks_buffer])
+                            casted_chunks = [chunk.cast(unified_schema) for chunk in chunks_buffer]
+                            yield new_key, pa.Table.from_batches(casted_chunks, schema=unified_schema)
+                        else:
+                            yield new_key, pa.Table.from_batches(chunks_buffer)
                     keys_buffer = [f"{key}[{cropped_chunk_length}:]"]
                     chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)]
                     chunks_buffer_size = len(chunk) - cropped_chunk_length
@@ -596,7 +630,23 @@ def _iter_arrow(self) -> Iterator[tuple[Key, pa.Table]]:
                 self._state_dict["batch_idx"] += 1
                 self._state_dict["num_chunks_since_previous_state"] = 0
                 self._state_dict["cropped_chunk_length"] = 0
-            yield new_key, pa.Table.from_batches(chunks_buffer)
+
+            if self.features:
+                expected_schema = pa.schema(self.features.type)
+                casted_chunks = []
+                for chunk in chunks_buffer:
+                    try:
+                        casted_chunks.append(chunk.cast(expected_schema))
+                    except (pa.ArrowInvalid, pa.ArrowNotImplementedError):
+                        casted_chunks.append(chunk)
+                yield new_key, pa.Table.from_batches(casted_chunks)
+            else:
+                if chunks_buffer:
+                    unified_schema = pa.unify_schemas([chunk.schema for chunk in chunks_buffer])
+                    casted_chunks = [chunk.cast(unified_schema) for chunk in chunks_buffer]
+                    yield new_key, pa.Table.from_batches(casted_chunks, schema=unified_schema)
+                else:
+                    yield new_key, pa.Table.from_batches(chunks_buffer)
 
     def shuffle_data_sources(self, generator: np.random.Generator) -> "RebatchedArrowExamplesIterable":
         return RebatchedArrowExamplesIterable(
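For context, here is a minimal standalone sketch (not part of this diff) of the failure mode the new else branch addresses: pa.Table.from_batches raises when the buffered record batches disagree on a column type, and unifying the schemas and casting each batch first avoids that. It assumes a recent pyarrow in which unify_schemas promotes null-typed fields by default and RecordBatch.cast is available; the batch contents are made up for illustration.

import pyarrow as pa

# Two record batches whose schemas disagree: "col" is all-null in one and int64 in the other.
batch_a = pa.RecordBatch.from_pydict({"col": pa.array([None, None], type=pa.null())})
batch_b = pa.RecordBatch.from_pydict({"col": pa.array([1, 2], type=pa.int64())})

try:
    pa.Table.from_batches([batch_a, batch_b])
except pa.ArrowInvalid as err:
    print("from_batches failed:", err)  # mismatched schemas are rejected

# Mirror the added else branch: unify the schemas, cast each batch, then build the table.
unified_schema = pa.unify_schemas([batch_a.schema, batch_b.schema])
casted_chunks = [batch.cast(unified_schema) for batch in (batch_a, batch_b)]
table = pa.Table.from_batches(casted_chunks, schema=unified_schema)
print(table.schema)  # col: int64, with the nulls from batch_a preserved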