Skip to content

Commit

Permalink
fix(python): Handle current position of file objects
Browse files (browse the repository at this point in the history)
  • Loading branch information
ruihe774 committed Jul 10, 2024
1 parent daf2e49 commit 366c935
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 28 deletions.
26 changes: 15 additions & 11 deletions crates/polars-io/src/utils/other.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[cfg(any(feature = "ipc_streaming", feature = "parquet"))]
use std::borrow::Cow;
use std::io::Read;
use std::fs::File;
use std::io::{Read, Seek};

use once_cell::sync::Lazy;
use polars_core::prelude::*;
Expand All @@ -10,17 +11,20 @@ use regex::{Regex, RegexBuilder};

use crate::mmap::{MmapBytesReader, ReaderBytes};

pub fn get_reader_bytes<'a, R: Read + MmapBytesReader + ?Sized>(
reader: &'a mut R,
) -> PolarsResult<ReaderBytes<'a>> {
pub fn get_reader_bytes<R: Read + MmapBytesReader + ?Sized>(
reader: &mut R,
) -> PolarsResult<ReaderBytes> {
// we have a file so we can mmap
if let Some(file) = reader.to_file() {
let mmap = unsafe { memmap::Mmap::map(file)? };

// somehow bck thinks borrows alias
// this is sound as file was already bound to 'a
use std::fs::File;
let file = unsafe { std::mem::transmute::<&File, &'a File>(file) };
// only seekable files are mmap-able
if let Some((file, offset)) = reader.to_file().and_then(|file| {
// stream_position() requires &mut (why), fake it here
#[allow(mutable_transmutes)]
let file = unsafe { std::mem::transmute::<&File, &mut File>(file) };
let offset = file.stream_position().ok()?;
let file: &File = file;
Some((file, offset))
}) {
let mmap = unsafe { memmap::MmapOptions::new().offset(offset).map(file)? };
Ok(ReaderBytes::Mapped(mmap, file))
} else {
// we can get the bytes for free
Expand Down
28 changes: 14 additions & 14 deletions py-polars/src/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,19 +211,22 @@ fn get_either_file_and_path(
let encoding = encoding.extract::<Cow<str>>()?;
Ok(encoding.eq_ignore_ascii_case("utf-8") || encoding.eq_ignore_ascii_case("utf8"))
};
let flush_file = |py_f: &Bound<PyAny>| -> PyResult<()> {
py_f.getattr("flush")?.call0()?;
Ok(())
};
#[cfg(target_family = "unix")]
if let Some(fd) = ((py_f.is_exact_instance(&io.getattr("FileIO").unwrap())
|| py_f.is_exact_instance(&io.getattr("BufferedReader").unwrap())
|| py_f.is_exact_instance(&io.getattr("BufferedWriter").unwrap())
|| py_f.is_exact_instance(&io.getattr("BufferedRandom").unwrap())
|| py_f.is_exact_instance(&io.getattr("BufferedRWPair").unwrap())
|| (py_f.is_exact_instance(&io.getattr("TextIOWrapper").unwrap())
&& is_utf8_encoding(&py_f)?))
&& (!write || flush_file(&py_f).is_ok()))
|| (py_f.is_exact_instance(&io.getattr("BufferedReader").unwrap())
|| py_f.is_exact_instance(&io.getattr("BufferedWriter").unwrap())
|| py_f.is_exact_instance(&io.getattr("BufferedRandom").unwrap())
|| py_f.is_exact_instance(&io.getattr("BufferedRWPair").unwrap())
|| (py_f.is_exact_instance(&io.getattr("TextIOWrapper").unwrap())
&& is_utf8_encoding(&py_f)?))
&& (
// invalidate read buffer
write || py_f.call_method1("seek", (0, 1)).is_ok()
))
&& (
// flush write buffer
!write || py_f.call_method0("flush").is_ok()
))
.then(|| {
py_f.getattr("fileno")
.and_then(|fileno| fileno.call0())
Expand Down Expand Up @@ -261,9 +264,6 @@ fn get_either_file_and_path(
)
.into());
}
if write {
flush_file(&py_f)?;
}
py_f.getattr("buffer")?
} else {
py_f
Expand Down
8 changes: 5 additions & 3 deletions py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2240,10 +2240,12 @@ def test_write_csv_raise_on_non_utf8_17328(


@pytest.mark.write_disk()
def test_write_csv_appending_17328(tmp_path: Path) -> None:
def test_write_csv_appending_17543(tmp_path: Path) -> None:
tmp_path.mkdir(exist_ok=True)
df = pl.DataFrame({"col": ["value"]})
with (tmp_path / "append.csv").open("w") as f:
f.write("# test\n")
pl.DataFrame({"col": ["value"]}).write_csv(f)
df.write_csv(f)
with (tmp_path / "append.csv").open("r") as f:
assert f.read() == "# test\ncol\nvalue\n"
assert f.readline() == "# test\n"
assert pl.read_csv(f).equals(df)

0 comments on commit 366c935

Please sign in to comment.