4 changes: 3 additions & 1 deletion src/datasets/arrow_writer.py
@@ -524,7 +524,9 @@ def _build_schema(self, inferred_schema: pa.Schema):

     def _build_writer(self, inferred_schema: pa.Schema):
         self._schema, self._features = self._build_schema(inferred_schema)
-        self.pa_writer = pa.RecordBatchStreamWriter(self.stream, self._schema)
+        self.pa_writer = pa.RecordBatchStreamWriter(
+            self.stream, self._schema, options=pa.ipc.IpcWriteOptions(allow_64bit=True)
+        )

     @property
     def schema(self):
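For context, `allow_64bit=True` lifts the Arrow IPC writer's default guard against buffers that need 64-bit lengths (i.e. anything past the 2 GiB / int32 limit), so oversized columns can be streamed. A minimal sketch of the option in isolation, with illustrative schema and data (small here; the option only matters once a buffer exceeds the 32-bit limit):

```python
import pyarrow as pa

# Sketch: stream a record batch with 64-bit buffer sizes permitted.
# With default IpcWriteOptions, the writer errors once a buffer needs
# a 64-bit length; allow_64bit=True disables that check.
schema = pa.schema([("audio", pa.large_list(pa.uint16()))])
sink = pa.BufferOutputStream()
options = pa.ipc.IpcWriteOptions(allow_64bit=True)
with pa.RecordBatchStreamWriter(sink, schema, options=options) as writer:
    batch = pa.record_batch(
        [pa.array([[1, 2, 3]], type=pa.large_list(pa.uint16()))], schema=schema
    )
    writer.write_batch(batch)
```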
8 changes: 6 additions & 2 deletions src/datasets/features/features.py
@@ -1535,9 +1535,13 @@ def list_of_pa_arrays_to_pyarrow_listarray(l_arr: list[Optional[pa.Array]]) -> pa.ListArray:
         [0] + [len(arr) for arr in l_arr], dtype=object
     ) # convert to dtype object to allow None insertion
     offsets = np.insert(offsets, null_indices, None)
-    offsets = pa.array(offsets, type=pa.int32())
     values = pa.concat_arrays(l_arr)
-    return pa.ListArray.from_arrays(offsets, values)
+    try:
+        offsets = pa.array(offsets, type=pa.int32())
+        return pa.ListArray.from_arrays(offsets, values)
+    except pa.lib.ArrowInvalid:
+        offsets = pa.array(offsets, type=pa.int64())
+        return pa.LargeListArray.from_arrays(offsets, values)


 def list_of_np_array_to_pyarrow_listarray(l_arr: list[np.ndarray], type: pa.DataType = None) -> pa.ListArray:
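The fallback works because `pa.ListArray` stores int32 offsets, capping the concatenated values at 2^31 - 1 elements, while `pa.LargeListArray` stores int64 offsets. A standalone sketch of the same try/except pattern, minus the None handling (the function name is ours, not the library's):

```python
import numpy as np
import pyarrow as pa

def concat_to_listarray(chunks: list[pa.Array]) -> pa.Array:
    # Sketch of the int32 -> int64 offset fallback; assumes no None chunks.
    offsets = np.cumsum([0] + [len(arr) for arr in chunks], dtype=object)
    values = pa.concat_arrays(chunks)
    try:
        # Fast path: 32-bit offsets are valid while len(values) < 2**31.
        return pa.ListArray.from_arrays(pa.array(offsets, type=pa.int32()), values)
    except pa.lib.ArrowInvalid:
        # An offset overflowed int32: fall back to 64-bit offsets / large_list.
        return pa.LargeListArray.from_arrays(pa.array(offsets, type=pa.int64()), values)

print(concat_to_listarray([pa.array([1, 2]), pa.array([3, 4, 5])]).type)
# -> list<item: int64>
```

Catching the overflow at `pa.array(...)` rather than pre-checking `len(values)` keeps the common small-array path cheap and only pays for the int64 conversion when it is actually needed.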
20 changes: 20 additions & 0 deletions tests/test_arrow_dataset.py
@@ -4783,3 +4783,23 @@ def test_from_polars_save_to_disk_and_load_from_disk_round_trip_with_large_list(
 def test_polars_round_trip():
     ds = Dataset.from_dict({"x": [[1, 2], [3, 4, 5]], "y": ["a", "b"]})
     assert isinstance(Dataset.from_polars(ds.to_polars()), Dataset)
+
+
+def test_map_int32_overflow():
+    # GH: 7821
+    def process_batch(batch):
+        res = []
+        for _ in batch["id"]:
+            res.append(np.zeros((2**31)).astype(np.uint16))
+
+        return {"audio": res}
+
+    ds = Dataset.from_dict({"id": [0]})
+    mapped_ds = ds.map(
+        process_batch,
+        batched=True,
+        batch_size=1,
+        num_proc=0,
+        remove_columns=ds.column_names,
+    )
+    assert isinstance(mapped_ds, Dataset)
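For scale: the batch above yields a single 2^31-element uint16 value, so the flattened list column's final offset is exactly one past the int32 maximum of 2^31 - 1, forcing both the `LargeListArray` fallback and the 64-bit IPC write. Note that `np.zeros((2**31))` first allocates a 16 GiB float64 buffer before the uint16 cast, so the test is memory-hungry.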