Skip to content

Commit 1847d32

Browse files
committed
Switched to storing mz_stream as a raw pointer to fix tree borrows violation.
Removed Deref and DerefMut implementations for StreamWrapper.
1 parent f0463d5 commit 1847d32

File tree

2 files changed

+100
-89
lines changed

2 files changed

+100
-89
lines changed

src/ffi/c.rs

Lines changed: 89 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use std::cmp;
44
use std::convert::TryFrom;
55
use std::fmt;
66
use std::marker;
7-
use std::ops::{Deref, DerefMut};
87
use std::os::raw::{c_int, c_uint, c_void};
98
use std::ptr;
109

@@ -21,7 +20,7 @@ impl ErrorMessage {
2120
}
2221

2322
pub struct StreamWrapper {
24-
pub inner: Box<mz_stream>,
23+
pub inner: *mut mz_stream,
2524
}
2625

2726
impl fmt::Debug for StreamWrapper {
@@ -32,8 +31,11 @@ impl fmt::Debug for StreamWrapper {
3231

3332
impl Default for StreamWrapper {
3433
fn default() -> StreamWrapper {
34+
// We need to store the mz_stream object as a raw pointer since
35+
// a cyclic structure is created in the `state` field, which points
36+
// back to the mz_stream object.
3537
StreamWrapper {
36-
inner: Box::new(mz_stream {
38+
inner: Box::into_raw(Box::new(mz_stream {
3739
next_in: ptr::null_mut(),
3840
avail_in: 0,
3941
total_in: 0,
@@ -54,7 +56,15 @@ impl Default for StreamWrapper {
5456
zalloc: Some(zalloc),
5557
#[cfg(not(all(feature = "any_zlib", not(feature = "cloudflare-zlib-sys"))))]
5658
zfree: Some(zfree),
57-
}),
59+
})),
60+
}
61+
}
62+
}
63+
64+
impl Drop for StreamWrapper {
65+
fn drop(&mut self) {
66+
unsafe {
67+
drop(Box::from_raw(self.inner));
5868
}
5969
}
6070
}
@@ -110,20 +120,6 @@ extern "C" fn zfree(_ptr: *mut c_void, address: *mut c_void) {
110120
}
111121
}
112122

113-
impl Deref for StreamWrapper {
114-
type Target = mz_stream;
115-
116-
fn deref(&self) -> &Self::Target {
117-
&*self.inner
118-
}
119-
}
120-
121-
impl DerefMut for StreamWrapper {
122-
fn deref_mut(&mut self) -> &mut Self::Target {
123-
&mut *self.inner
124-
}
125-
}
126-
127123
unsafe impl<D: Direction> Send for Stream<D> {}
128124
unsafe impl<D: Direction> Sync for Stream<D> {}
129125

@@ -148,7 +144,7 @@ pub struct Stream<D: Direction> {
148144

149145
impl<D: Direction> Stream<D> {
150146
pub fn msg(&self) -> ErrorMessage {
151-
let msg = self.stream_wrapper.msg;
147+
let msg = unsafe { (*self.stream_wrapper.inner).msg };
152148
ErrorMessage(if msg.is_null() {
153149
None
154150
} else {
@@ -161,7 +157,7 @@ impl<D: Direction> Stream<D> {
161157
impl<D: Direction> Drop for Stream<D> {
162158
fn drop(&mut self) {
163159
unsafe {
164-
let _ = D::destroy(&mut *self.stream_wrapper);
160+
let _ = D::destroy(self.stream_wrapper.inner);
165161
}
166162
}
167163
}
@@ -185,9 +181,9 @@ pub struct Inflate {
185181
impl InflateBackend for Inflate {
186182
fn make(zlib_header: bool, window_bits: u8) -> Self {
187183
unsafe {
188-
let mut state = StreamWrapper::default();
184+
let state = StreamWrapper::default();
189185
let ret = mz_inflateInit2(
190-
&mut *state,
186+
state.inner,
191187
if zlib_header {
192188
window_bits as c_int
193189
} else {
@@ -212,33 +208,40 @@ impl InflateBackend for Inflate {
212208
output: &mut [u8],
213209
flush: FlushDecompress,
214210
) -> Result<Status, DecompressError> {
215-
let raw = &mut *self.inner.stream_wrapper;
216-
raw.msg = ptr::null_mut();
217-
raw.next_in = input.as_ptr() as *mut u8;
218-
raw.avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint;
219-
raw.next_out = output.as_mut_ptr();
220-
raw.avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint;
221-
222-
let rc = unsafe { mz_inflate(raw, flush as c_int) };
223-
224-
// Unfortunately the total counters provided by zlib might be only
225-
// 32 bits wide and overflow while processing large amounts of data.
226-
self.inner.total_in += (raw.next_in as usize - input.as_ptr() as usize) as u64;
227-
self.inner.total_out += (raw.next_out as usize - output.as_ptr() as usize) as u64;
228-
229-
// reset these pointers so we don't accidentally read them later
230-
raw.next_in = ptr::null_mut();
231-
raw.avail_in = 0;
232-
raw.next_out = ptr::null_mut();
233-
raw.avail_out = 0;
234-
235-
match rc {
236-
MZ_DATA_ERROR | MZ_STREAM_ERROR => mem::decompress_failed(self.inner.msg()),
237-
MZ_OK => Ok(Status::Ok),
238-
MZ_BUF_ERROR => Ok(Status::BufError),
239-
MZ_STREAM_END => Ok(Status::StreamEnd),
240-
MZ_NEED_DICT => mem::decompress_need_dict(raw.adler as u32),
241-
c => panic!("unknown return code: {}", c),
211+
let raw = self.inner.stream_wrapper.inner;
212+
// We need to access the `inner` field of the `StreamWrapper` object
213+
// as a raw pointer here since the field `state` in `mz_stream` is
214+
// a pointer back to the `mz_stream` object. Any mutable borrow against
215+
// inner will become invalidated by `mz_inflate`, leading to an invalid
216+
// dereference after that function returns.
217+
unsafe {
218+
(*raw).msg = ptr::null_mut();
219+
(*raw).next_in = input.as_ptr() as *mut u8;
220+
(*raw).avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint;
221+
(*raw).next_out = output.as_mut_ptr();
222+
(*raw).avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint;
223+
224+
let rc = mz_inflate(raw, flush as c_int);
225+
226+
// Unfortunately the total counters provided by zlib might be only
227+
// 32 bits wide and overflow while processing large amounts of data.
228+
self.inner.total_in += ((*raw).next_in as usize - input.as_ptr() as usize) as u64;
229+
self.inner.total_out += ((*raw).next_out as usize - output.as_ptr() as usize) as u64;
230+
231+
// reset these pointers so we don't accidentally read them later
232+
(*raw).next_in = ptr::null_mut();
233+
(*raw).avail_in = 0;
234+
(*raw).next_out = ptr::null_mut();
235+
(*raw).avail_out = 0;
236+
237+
match rc {
238+
MZ_DATA_ERROR | MZ_STREAM_ERROR => mem::decompress_failed(self.inner.msg()),
239+
MZ_OK => Ok(Status::Ok),
240+
MZ_BUF_ERROR => Ok(Status::BufError),
241+
MZ_STREAM_END => Ok(Status::StreamEnd),
242+
MZ_NEED_DICT => mem::decompress_need_dict((*raw).adler as u32),
243+
c => panic!("unknown return code: {}", c),
244+
}
242245
}
243246
}
244247

@@ -249,7 +252,7 @@ impl InflateBackend for Inflate {
249252
-MZ_DEFAULT_WINDOW_BITS
250253
};
251254
unsafe {
252-
inflateReset2(&mut *self.inner.stream_wrapper, bits);
255+
inflateReset2(self.inner.stream_wrapper.inner, bits);
253256
}
254257
self.inner.total_out = 0;
255258
self.inner.total_in = 0;
@@ -276,9 +279,9 @@ pub struct Deflate {
276279
impl DeflateBackend for Deflate {
277280
fn make(level: Compression, zlib_header: bool, window_bits: u8) -> Self {
278281
unsafe {
279-
let mut state = StreamWrapper::default();
282+
let state = StreamWrapper::default();
280283
let ret = mz_deflateInit2(
281-
&mut *state,
284+
state.inner,
282285
level.0 as c_int,
283286
MZ_DEFLATED,
284287
if zlib_header {
@@ -306,39 +309,46 @@ impl DeflateBackend for Deflate {
306309
output: &mut [u8],
307310
flush: FlushCompress,
308311
) -> Result<Status, CompressError> {
309-
let raw = &mut *self.inner.stream_wrapper;
310-
raw.msg = ptr::null_mut();
311-
raw.next_in = input.as_ptr() as *mut _;
312-
raw.avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint;
313-
raw.next_out = output.as_mut_ptr();
314-
raw.avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint;
315-
316-
let rc = unsafe { mz_deflate(raw, flush as c_int) };
317-
318-
// Unfortunately the total counters provided by zlib might be only
319-
// 32 bits wide and overflow while processing large amounts of data.
320-
self.inner.total_in += (raw.next_in as usize - input.as_ptr() as usize) as u64;
321-
self.inner.total_out += (raw.next_out as usize - output.as_ptr() as usize) as u64;
322-
323-
// reset these pointers so we don't accidentally read them later
324-
raw.next_in = ptr::null_mut();
325-
raw.avail_in = 0;
326-
raw.next_out = ptr::null_mut();
327-
raw.avail_out = 0;
328-
329-
match rc {
330-
MZ_OK => Ok(Status::Ok),
331-
MZ_BUF_ERROR => Ok(Status::BufError),
332-
MZ_STREAM_END => Ok(Status::StreamEnd),
333-
MZ_STREAM_ERROR => mem::compress_failed(self.inner.msg()),
334-
c => panic!("unknown return code: {}", c),
312+
let raw = self.inner.stream_wrapper.inner;
313+
// We need to access the `inner` field of the `StreamWrapper` object
314+
// as a raw pointer here since the field `state` in `mz_stream` is
315+
// a pointer back to the `mz_stream` object. Any mutable borrow against
316+
// inner will become invalidated by `mz_deflate`, leading to an invalid
317+
// dereference after that function returns.
318+
unsafe {
319+
(*raw).msg = ptr::null_mut();
320+
(*raw).next_in = input.as_ptr() as *mut _;
321+
(*raw).avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint;
322+
(*raw).next_out = output.as_mut_ptr();
323+
(*raw).avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint;
324+
325+
let rc = mz_deflate(raw, flush as c_int);
326+
327+
// Unfortunately the total counters provided by zlib might be only
328+
// 32 bits wide and overflow while processing large amounts of data.
329+
330+
self.inner.total_in += ((*raw).next_in as usize - input.as_ptr() as usize) as u64;
331+
self.inner.total_out += ((*raw).next_out as usize - output.as_ptr() as usize) as u64;
332+
// reset these pointers so we don't accidentally read them later
333+
(*raw).next_in = ptr::null_mut();
334+
(*raw).avail_in = 0;
335+
(*raw).next_out = ptr::null_mut();
336+
(*raw).avail_out = 0;
337+
338+
match rc {
339+
MZ_OK => Ok(Status::Ok),
340+
MZ_BUF_ERROR => Ok(Status::BufError),
341+
MZ_STREAM_END => Ok(Status::StreamEnd),
342+
MZ_STREAM_ERROR => mem::compress_failed(self.inner.msg()),
343+
c => panic!("unknown return code: {}", c),
344+
}
335345
}
336346
}
337347

338348
fn reset(&mut self) {
339349
self.inner.total_in = 0;
340350
self.inner.total_out = 0;
341-
let rc = unsafe { mz_deflateReset(&mut *self.inner.stream_wrapper) };
351+
let rc = unsafe { mz_deflateReset(self.inner.stream_wrapper.inner) };
342352
assert_eq!(rc, MZ_OK);
343353
}
344354
}

src/mem.rs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -265,16 +265,16 @@ impl Compress {
265265
/// Returns the Adler-32 checksum of the dictionary.
266266
#[cfg(feature = "any_zlib")]
267267
pub fn set_dictionary(&mut self, dictionary: &[u8]) -> Result<u32, CompressError> {
268-
let stream = &mut *self.inner.inner.stream_wrapper;
269-
stream.msg = std::ptr::null_mut();
268+
let stream = self.inner.inner.stream_wrapper.inner;
270269
let rc = unsafe {
270+
(*stream).msg = std::ptr::null_mut();
271271
assert!(dictionary.len() < ffi::uInt::MAX as usize);
272272
ffi::deflateSetDictionary(stream, dictionary.as_ptr(), dictionary.len() as ffi::uInt)
273273
};
274274

275275
match rc {
276276
ffi::MZ_STREAM_ERROR => compress_failed(self.inner.inner.msg()),
277-
ffi::MZ_OK => Ok(stream.adler as u32),
277+
ffi::MZ_OK => Ok(unsafe { (*stream).adler } as u32),
278278
c => panic!("unknown return code: {}", c),
279279
}
280280
}
@@ -299,9 +299,10 @@ impl Compress {
299299
#[cfg(feature = "any_zlib")]
300300
pub fn set_level(&mut self, level: Compression) -> Result<(), CompressError> {
301301
use std::os::raw::c_int;
302-
let stream = &mut *self.inner.inner.stream_wrapper;
303-
stream.msg = std::ptr::null_mut();
304-
302+
let stream = self.inner.inner.stream_wrapper.inner;
303+
unsafe {
304+
(*stream).msg = std::ptr::null_mut();
305+
}
305306
let rc = unsafe { ffi::deflateParams(stream, level.0 as c_int, ffi::MZ_DEFAULT_STRATEGY) };
306307

307308
match rc {
@@ -476,17 +477,17 @@ impl Decompress {
476477
/// Specifies the decompression dictionary to use.
477478
#[cfg(feature = "any_zlib")]
478479
pub fn set_dictionary(&mut self, dictionary: &[u8]) -> Result<u32, DecompressError> {
479-
let stream = &mut *self.inner.inner.stream_wrapper;
480-
stream.msg = std::ptr::null_mut();
480+
let stream = self.inner.inner.stream_wrapper.inner;
481481
let rc = unsafe {
482+
(*stream).msg = std::ptr::null_mut();
482483
assert!(dictionary.len() < ffi::uInt::MAX as usize);
483484
ffi::inflateSetDictionary(stream, dictionary.as_ptr(), dictionary.len() as ffi::uInt)
484485
};
485486

486487
match rc {
487488
ffi::MZ_STREAM_ERROR => decompress_failed(self.inner.inner.msg()),
488-
ffi::MZ_DATA_ERROR => decompress_need_dict(stream.adler as u32),
489-
ffi::MZ_OK => Ok(stream.adler as u32),
489+
ffi::MZ_DATA_ERROR => decompress_need_dict(unsafe { (*stream).adler } as u32),
490+
ffi::MZ_OK => Ok(unsafe { (*stream).adler } as u32),
490491
c => panic!("unknown return code: {}", c),
491492
}
492493
}

0 commit comments

Comments
 (0)