Skip to content

Commit e2ec4b6

Browse files
committed
Add std::io::Seek instance for std::io::Take
1 parent f9e0239 commit e2ec4b6

File tree

3 files changed

+199
-1
lines changed

3 files changed

+199
-1
lines changed

library/std/src/io/mod.rs

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1173,7 +1173,7 @@ pub trait Read {
11731173
where
11741174
Self: Sized,
11751175
{
1176-
Take { inner: self, limit }
1176+
Take { inner: self, len: limit, limit }
11771177
}
11781178
}
11791179

@@ -2822,6 +2822,7 @@ impl<T, U> SizeHint for Chain<T, U> {
28222822
#[derive(Debug)]
28232823
pub struct Take<T> {
28242824
inner: T,
2825+
len: u64,
28252826
limit: u64,
28262827
}
28272828

@@ -2856,6 +2857,12 @@ impl<T> Take<T> {
28562857
self.limit
28572858
}
28582859

2860+
/// Returns the number of bytes read so far.
2861+
#[unstable(feature = "seek_io_take_position", issue = "97227")]
2862+
pub fn position(&self) -> u64 {
2863+
self.len - self.limit
2864+
}
2865+
28592866
/// Sets the number of bytes that can be read before this instance will
28602867
/// return EOF. This is the same as constructing a new `Take` instance, so
28612868
/// the amount of bytes read and the previous limit value don't matter when
@@ -2881,6 +2888,7 @@ impl<T> Take<T> {
28812888
/// ```
28822889
#[stable(feature = "take_set_limit", since = "1.27.0")]
28832890
pub fn set_limit(&mut self, limit: u64) {
2891+
self.len = limit;
28842892
self.limit = limit;
28852893
}
28862894

@@ -3064,6 +3072,49 @@ impl<T> SizeHint for Take<T> {
30643072
}
30653073
}
30663074

3075+
#[stable(feature = "seek_io_take", since = "CURRENT_RUSTC_VERSION")]
3076+
impl<T: Seek> Seek for Take<T> {
3077+
fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
3078+
let new_position = match pos {
3079+
SeekFrom::Start(v) => Some(v),
3080+
SeekFrom::Current(v) => self.position().checked_add_signed(v),
3081+
SeekFrom::End(v) => self.len.checked_add_signed(v),
3082+
};
3083+
let new_position = match new_position {
3084+
Some(v) if v <= self.len => v,
3085+
_ => return Err(ErrorKind::InvalidInput.into()),
3086+
};
3087+
while new_position != self.position() {
3088+
if let Some(offset) = new_position.checked_signed_diff(self.position()) {
3089+
self.inner.seek_relative(offset)?;
3090+
self.limit = self.limit.wrapping_sub(offset as u64);
3091+
break;
3092+
}
3093+
let offset = if new_position > self.position() { i64::MAX } else { i64::MIN };
3094+
self.inner.seek_relative(offset)?;
3095+
self.limit = self.limit.wrapping_sub(offset as u64);
3096+
}
3097+
Ok(new_position)
3098+
}
3099+
3100+
fn stream_len(&mut self) -> Result<u64> {
3101+
Ok(self.len)
3102+
}
3103+
3104+
fn stream_position(&mut self) -> Result<u64> {
3105+
Ok(self.position())
3106+
}
3107+
3108+
fn seek_relative(&mut self, offset: i64) -> Result<()> {
3109+
if !self.position().checked_add_signed(offset).is_some_and(|p| p <= self.len) {
3110+
return Err(ErrorKind::InvalidInput.into());
3111+
}
3112+
self.inner.seek_relative(offset)?;
3113+
self.limit = self.limit.wrapping_sub(offset as u64);
3114+
Ok(())
3115+
}
3116+
}
3117+
30673118
/// An iterator over `u8` values of a reader.
30683119
///
30693120
/// This struct is generally created by calling [`bytes`] on a reader.

library/std/src/io/tests.rs

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,152 @@ fn seek_position() -> io::Result<()> {
416416
Ok(())
417417
}
418418

419+
#[test]
420+
fn take_seek() -> io::Result<()> {
421+
let mut buf = Cursor::new(b"0123456789");
422+
buf.set_position(2);
423+
let mut take = buf.by_ref().take(4);
424+
let mut buf1 = [0u8; 1];
425+
let mut buf2 = [0u8; 2];
426+
assert_eq!(take.position(), 0);
427+
428+
assert_eq!(take.seek(SeekFrom::Start(0))?, 0);
429+
take.read_exact(&mut buf2)?;
430+
assert_eq!(buf2, [b'2', b'3']);
431+
assert_eq!(take.seek(SeekFrom::Start(1))?, 1);
432+
take.read_exact(&mut buf2)?;
433+
assert_eq!(buf2, [b'3', b'4']);
434+
assert_eq!(take.seek(SeekFrom::Start(2))?, 2);
435+
take.read_exact(&mut buf2)?;
436+
assert_eq!(buf2, [b'4', b'5']);
437+
assert_eq!(take.seek(SeekFrom::Start(3))?, 3);
438+
take.read_exact(&mut buf1)?;
439+
assert_eq!(buf1, [b'5']);
440+
assert_eq!(take.seek(SeekFrom::Start(4))?, 4);
441+
442+
assert_eq!(take.seek(SeekFrom::End(0))?, 4);
443+
assert_eq!(take.seek(SeekFrom::End(-1))?, 3);
444+
take.read_exact(&mut buf1)?;
445+
assert_eq!(buf1, [b'5']);
446+
assert_eq!(take.seek(SeekFrom::End(-2))?, 2);
447+
take.read_exact(&mut buf2)?;
448+
assert_eq!(buf2, [b'4', b'5']);
449+
assert_eq!(take.seek(SeekFrom::End(-3))?, 1);
450+
take.read_exact(&mut buf2)?;
451+
assert_eq!(buf2, [b'3', b'4']);
452+
assert_eq!(take.seek(SeekFrom::End(-4))?, 0);
453+
take.read_exact(&mut buf2)?;
454+
assert_eq!(buf2, [b'2', b'3']);
455+
456+
assert_eq!(take.seek(SeekFrom::Current(0))?, 2);
457+
take.read_exact(&mut buf2)?;
458+
assert_eq!(buf2, [b'4', b'5']);
459+
460+
assert_eq!(take.seek(SeekFrom::Current(-3))?, 1);
461+
take.read_exact(&mut buf2)?;
462+
assert_eq!(buf2, [b'3', b'4']);
463+
464+
assert_eq!(take.seek(SeekFrom::Current(-1))?, 2);
465+
take.read_exact(&mut buf2)?;
466+
assert_eq!(buf2, [b'4', b'5']);
467+
468+
assert_eq!(take.seek(SeekFrom::Current(-4))?, 0);
469+
take.read_exact(&mut buf2)?;
470+
assert_eq!(buf2, [b'2', b'3']);
471+
472+
assert_eq!(take.seek(SeekFrom::Current(2))?, 4);
473+
Ok(())
474+
}
475+
476+
#[test]
477+
#[should_panic]
478+
fn take_seek_out_of_bounds_start() {
479+
let buf = Cursor::new(b"0123456789");
480+
let mut take = buf.take(2);
481+
take.seek(SeekFrom::Start(3)).unwrap();
482+
}
483+
484+
#[test]
485+
#[should_panic]
486+
fn take_seek_out_of_bounds_end_forward() {
487+
let buf = Cursor::new(b"0123456789");
488+
let mut take = buf.take(2);
489+
take.seek(SeekFrom::End(1)).unwrap();
490+
}
491+
492+
#[test]
493+
#[should_panic]
494+
fn take_seek_out_of_bounds_end_before_start() {
495+
let buf = Cursor::new(b"0123456789");
496+
let mut take = buf.take(2);
497+
take.seek(SeekFrom::End(-3)).unwrap();
498+
}
499+
500+
#[test]
501+
#[should_panic]
502+
fn take_seek_out_of_bounds_current_before_start() {
503+
let buf = Cursor::new(b"0123456789");
504+
let mut take = buf.take(2);
505+
take.seek(SeekFrom::Current(-1)).unwrap();
506+
}
507+
508+
#[test]
509+
#[should_panic]
510+
fn take_seek_out_of_bounds_current_after_end() {
511+
let buf = Cursor::new(b"0123456789");
512+
let mut take = buf.take(2);
513+
take.seek(SeekFrom::Current(3)).unwrap();
514+
}
515+
516+
struct ExampleHugeRangeOfZeroes {
517+
position: u64,
518+
}
519+
520+
impl Read for ExampleHugeRangeOfZeroes {
521+
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
522+
let max = buf.len().min(usize::MAX);
523+
for i in 0..max {
524+
if self.position == u64::MAX {
525+
return Ok(i);
526+
}
527+
self.position += 1;
528+
buf[i] = 0;
529+
}
530+
Ok(max)
531+
}
532+
}
533+
534+
impl Seek for ExampleHugeRangeOfZeroes {
535+
fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
536+
match pos {
537+
io::SeekFrom::Start(i) => self.position = i,
538+
io::SeekFrom::End(i) if i >= 0 => self.position = u64::MAX,
539+
io::SeekFrom::End(i) => self.position = self.position - i.unsigned_abs(),
540+
io::SeekFrom::Current(i) => {
541+
self.position = if i >= 0 {
542+
self.position.saturating_add(i.unsigned_abs())
543+
} else {
544+
self.position.saturating_sub(i.unsigned_abs())
545+
};
546+
}
547+
}
548+
Ok(self.position)
549+
}
550+
}
551+
552+
#[test]
553+
fn take_seek_big_offsets() -> io::Result<()> {
554+
let inner = ExampleHugeRangeOfZeroes { position: 1 };
555+
let mut take = inner.take(u64::MAX - 2);
556+
assert_eq!(take.seek(io::SeekFrom::Start(u64::MAX - 2))?, u64::MAX - 2);
557+
assert_eq!(take.inner.position, u64::MAX - 1);
558+
assert_eq!(take.seek(io::SeekFrom::Start(0))?, 0);
559+
assert_eq!(take.inner.position, 1);
560+
assert_eq!(take.seek(io::SeekFrom::End(-1))?, u64::MAX - 3);
561+
assert_eq!(take.inner.position, u64::MAX - 2);
562+
Ok(())
563+
}
564+
419565
// A simple example reader which uses the default implementation of
420566
// read_to_end.
421567
struct ExampleSliceReader<'a> {

library/std/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@
320320
#![feature(thread_local)]
321321
#![feature(try_blocks)]
322322
#![feature(type_alias_impl_trait)]
323+
#![feature(unsigned_signed_diff)]
323324
// tidy-alphabetical-end
324325
//
325326
// Library features (core):

0 commit comments

Comments
 (0)