2424#include  < zstd.h> 
2525
2626#include  " arrow/status.h" 
27+ #include  " arrow/util/logging.h" 
2728#include  " arrow/util/macros.h" 
2829
2930using  std::size_t ;
@@ -34,6 +35,12 @@ namespace util {
3435//  XXX level = 1 probably doesn't compress very much
3536constexpr  int  kZSTDDefaultCompressionLevel  = 1 ;
3637
38+ static  Status ZSTDError (size_t  ret, const  char * prefix_msg) {
39+   std::stringstream ss;
40+   ss << prefix_msg << ZSTD_getErrorName (ret);
41+   return  Status::IOError (ss.str ());
42+ }
43+ 
3744//  ----------------------------------------------------------------------
3845//  ZSTD decompressor implementation
3946
@@ -47,7 +54,7 @@ class ZSTDDecompressor : public Decompressor {
4754    finished_ = false ;
4855    size_t  ret = ZSTD_initDStream (stream_);
4956    if  (ZSTD_isError (ret)) {
50-       return  ZSTDError (ret, " zstd  init failed: " 
57+       return  ZSTDError (ret, " ZSTD  init failed: " 
5158    } else  {
5259      return  Status::OK ();
5360    }
@@ -69,7 +76,7 @@ class ZSTDDecompressor : public Decompressor {
6976    size_t  ret;
7077    ret = ZSTD_decompressStream (stream_, &out_buf, &in_buf);
7178    if  (ZSTD_isError (ret)) {
72-       return  ZSTDError (ret, " zstd  decompress failed: " 
79+       return  ZSTDError (ret, " ZSTD  decompress failed: " 
7380    }
7481    *bytes_read = static_cast <int64_t >(in_buf.pos );
7582    *bytes_written = static_cast <int64_t >(out_buf.pos );
@@ -81,12 +88,6 @@ class ZSTDDecompressor : public Decompressor {
8188  bool  IsFinished () override  { return  finished_; }
8289
8390 protected: 
84-   Status ZSTDError (size_t  ret, const  char * prefix_msg) {
85-     std::stringstream ss;
86-     ss << prefix_msg << ZSTD_getErrorName (ret);
87-     return  Status::IOError (ss.str ());
88-   }
89- 
9091  ZSTD_DStream* stream_;
9192  bool  finished_;
9293};
@@ -103,7 +104,7 @@ class ZSTDCompressor : public Compressor {
103104  Status Init () {
104105    size_t  ret = ZSTD_initCStream (stream_, kZSTDDefaultCompressionLevel );
105106    if  (ZSTD_isError (ret)) {
106-       return  ZSTDError (ret, " zstd  init failed: " 
107+       return  ZSTDError (ret, " ZSTD  init failed: " 
107108    } else  {
108109      return  Status::OK ();
109110    }
@@ -119,12 +120,6 @@ class ZSTDCompressor : public Compressor {
119120             bool * should_retry) override ;
120121
121122 protected: 
122-   Status ZSTDError (size_t  ret, const  char * prefix_msg) {
123-     std::stringstream ss;
124-     ss << prefix_msg << ZSTD_getErrorName (ret);
125-     return  Status::IOError (ss.str ());
126-   }
127- 
128123  ZSTD_CStream* stream_;
129124};
130125
@@ -144,7 +139,7 @@ Status ZSTDCompressor::Compress(int64_t input_len, const uint8_t* input,
144139  size_t  ret;
145140  ret = ZSTD_compressStream (stream_, &out_buf, &in_buf);
146141  if  (ZSTD_isError (ret)) {
147-     return  ZSTDError (ret, " zstd  compress failed: " 
142+     return  ZSTDError (ret, " ZSTD  compress failed: " 
148143  }
149144  *bytes_read = static_cast <int64_t >(in_buf.pos );
150145  *bytes_written = static_cast <int64_t >(out_buf.pos );
@@ -162,7 +157,7 @@ Status ZSTDCompressor::Flush(int64_t output_len, uint8_t* output, int64_t* bytes
162157  size_t  ret;
163158  ret = ZSTD_flushStream (stream_, &out_buf);
164159  if  (ZSTD_isError (ret)) {
165-     return  ZSTDError (ret, " zstd  flush failed: " 
160+     return  ZSTDError (ret, " ZSTD  flush failed: " 
166161  }
167162  *bytes_written = static_cast <int64_t >(out_buf.pos );
168163  *should_retry = ret > 0 ;
@@ -180,7 +175,7 @@ Status ZSTDCompressor::End(int64_t output_len, uint8_t* output, int64_t* bytes_w
180175  size_t  ret;
181176  ret = ZSTD_endStream (stream_, &out_buf);
182177  if  (ZSTD_isError (ret)) {
183-     return  ZSTDError (ret, " zstd  end failed: " 
178+     return  ZSTDError (ret, " ZSTD  end failed: " 
184179  }
185180  *bytes_written = static_cast <int64_t >(out_buf.pos );
186181  *should_retry = ret > 0 ;
@@ -206,10 +201,20 @@ Status ZSTDCodec::MakeDecompressor(std::shared_ptr<Decompressor>* out) {
206201
207202Status ZSTDCodec::Decompress (int64_t  input_len, const  uint8_t * input, int64_t  output_len,
208203                             uint8_t * output_buffer) {
209-   int64_t  decompressed_size =
210-       ZSTD_decompress (output_buffer, static_cast <size_t >(output_len), input,
211-                       static_cast <size_t >(input_len));
212-   if  (decompressed_size != output_len) {
204+   if  (output_buffer == nullptr ) {
205+     //  We may pass a NULL 0-byte output buffer but some zstd versions demand
206+     //  a valid pointer: https://github.com/facebook/zstd/issues/1385
207+     static  uint8_t  empty_buffer[1 ];
208+     DCHECK_EQ (output_len, 0 );
209+     output_buffer = empty_buffer;
210+   }
211+ 
212+   size_t  ret = ZSTD_decompress (output_buffer, static_cast <size_t >(output_len), input,
213+                                static_cast <size_t >(input_len));
214+   if  (ZSTD_isError (ret)) {
215+     return  ZSTDError (ret, " ZSTD decompression failed: " 
216+   }
217+   if  (static_cast <int64_t >(ret) != output_len) {
213218    return  Status::IOError (" Corrupt ZSTD compressed data." 
214219  }
215220  return  Status::OK ();
@@ -223,12 +228,13 @@ int64_t ZSTDCodec::MaxCompressedLen(int64_t input_len,
223228Status ZSTDCodec::Compress (int64_t  input_len, const  uint8_t * input,
224229                           int64_t  output_buffer_len, uint8_t * output_buffer,
225230                           int64_t * output_length) {
226-   *output_length  =
231+   size_t  ret  =
227232      ZSTD_compress (output_buffer, static_cast <size_t >(output_buffer_len), input,
228233                    static_cast <size_t >(input_len), kZSTDDefaultCompressionLevel );
229-   if  (ZSTD_isError (*output_length )) {
230-     return  Status::IOError ( " ZSTD compression failure. " 
234+   if  (ZSTD_isError (ret )) {
235+     return  ZSTDError (ret,  " ZSTD compression failed:  " 
231236  }
237+   *output_length = static_cast <int64_t >(ret);
232238  return  Status::OK ();
233239}
234240
0 commit comments