@@ -56,6 +56,9 @@ struct header_info
56
56
57
57
/* * A vector of values indicating the shape of each dimension of the tensor. */
58
58
std::vector<size_t > shape;
59
+
60
+ /* * Value used to indicate the maximum length of an element (used by Unicode strings) */
61
+ std::size_t max_element_length;
59
62
};
60
63
61
64
/* * Writes an NPY header to the provided stream.
@@ -110,6 +113,25 @@ void write_npy_header(std::basic_ostream<CHAR> &output,
110
113
output.write (reinterpret_cast <const CHAR *>(end.data ()), end.length ());
111
114
}
112
115
116
+ template <typename T, typename CHAR>
117
+ void copy_to (const T* data_ptr, std::size_t num_elements, std::basic_ostream<CHAR>& output, npy::endian_t endianness)
118
+ {
119
+ if (endianness == npy::endian_t ::NATIVE || endianness == native_endian ())
120
+ {
121
+ output.write (reinterpret_cast <const CHAR *>(data_ptr), num_elements * sizeof (T));
122
+ }
123
+ else
124
+ {
125
+ CHAR buffer[sizeof (T)];
126
+ for (auto curr = data_ptr; curr < data_ptr + num_elements; ++curr)
127
+ {
128
+ const CHAR *start = reinterpret_cast <const CHAR *>(curr);
129
+ std::reverse_copy (start, start + sizeof (T), buffer);
130
+ output.write (buffer, sizeof (T));
131
+ }
132
+ }
133
+ }
134
+
113
135
/* * Saves a tensor to the provided stream.
114
136
* \tparam T the data type
115
137
* \tparam TENSOR the tensor type.
@@ -120,32 +142,72 @@ void write_npy_header(std::basic_ostream<CHAR> &output,
120
142
*/
121
143
template <typename T,
122
144
template <typename > class TENSOR ,
123
- typename CHAR>
145
+ typename CHAR,
146
+ std::enable_if_t <!std::is_same<std::wstring, T>::value, int > = 42 >
124
147
void save (std::basic_ostream<CHAR> &output,
125
148
const TENSOR<T> &tensor,
126
149
endian_t endianness = npy::endian_t ::NATIVE)
127
150
{
128
151
auto dtype = to_dtype (tensor.dtype (), endianness);
129
152
write_npy_header (output, dtype, tensor.fortran_order (), tensor.shape ());
153
+ copy_to (tensor.data (), tensor.size (), output, endianness);
154
+ };
130
155
131
- if (endianness == npy::endian_t ::NATIVE ||
132
- endianness == native_endian () ||
133
- dtype[0 ] == ' |' )
156
+ /* * Saves a unicode string tensor to the provided stream.
157
+ * \tparam TENSOR the tensor type.
158
+ * \param output the output stream
159
+ * \param tensor the tensor
160
+ * \param endianness the endianness to use in saving the tensor
161
+ * \sa npy::tensor
162
+ */
163
+ template <typename T,
164
+ template <typename > class TENSOR ,
165
+ typename CHAR,
166
+ std::enable_if_t <std::is_same<std::wstring, T>::value, int > = 42 >
167
+ void save (std::basic_ostream<CHAR> &output,
168
+ const TENSOR<std::wstring> &tensor,
169
+ endian_t endianness = npy::endian_t ::NATIVE)
170
+ {
171
+ std::size_t max_length = 0 ;
172
+ for (const auto & element : tensor)
134
173
{
135
- output.write (reinterpret_cast <const CHAR *>(tensor.data ()), tensor.size () * sizeof (T));
174
+ if (element.size () > max_length)
175
+ {
176
+ max_length = element.size ();
177
+ }
136
178
}
137
- else
179
+
180
+ if (endianness == npy::endian_t ::NATIVE)
138
181
{
139
- CHAR buffer[sizeof (T)];
140
- for (auto curr = tensor.data (); curr < tensor.data () + tensor.size (); ++curr)
182
+ endianness = native_endian ();
183
+ }
184
+
185
+ std::string dtype = " >U" + std::to_string (max_length);
186
+ if (endianness == npy::endian_t ::LITTLE)
187
+ {
188
+ dtype = " <U" + std::to_string (max_length);
189
+ }
190
+
191
+ write_npy_header (output, dtype, tensor.fortran_order (), tensor.shape ());
192
+
193
+ std::vector<std::int32_t > unicode (tensor.size () * max_length, 0 );
194
+ auto word_start = unicode.begin ();
195
+ for (const auto & element : tensor)
196
+ {
197
+ auto char_it = word_start;
198
+ for (const auto & wchar : element)
141
199
{
142
- const CHAR *start = reinterpret_cast <const CHAR *>(curr);
143
- std::reverse_copy (start, start + sizeof (T), buffer);
144
- output.write (buffer, sizeof (T));
200
+ *char_it = static_cast <std::int32_t >(wchar);
201
+ char_it += 1 ;
145
202
}
203
+
204
+ word_start += max_length;
146
205
}
206
+
207
+ copy_to (unicode.data (), unicode.size (), output, endianness);
147
208
};
148
209
210
+
149
211
/* * Saves a tensor to the provided location on disk.
150
212
* \tparam T the data type
151
213
* \tparam TENSOR the tensor type.
@@ -166,7 +228,7 @@ void save(const std::string &path,
166
228
throw std::invalid_argument (" path" );
167
229
}
168
230
169
- save (output, tensor, endianness);
231
+ save<T, TENSOR, char > (output, tensor, endianness);
170
232
};
171
233
172
234
/* * Read an NPY header from the provided stream.
@@ -202,6 +264,26 @@ header_info read_npy_header(std::basic_istream<CHAR> &input)
202
264
return header_info (dictionary);
203
265
}
204
266
267
+ template <typename T, typename CHAR>
268
+ void copy_to (std::basic_istream<CHAR> &input, T* data_ptr, std::size_t num_elements, npy::endian_t endianness)
269
+ {
270
+ if (endianness == npy::endian_t ::NATIVE || endianness == native_endian ())
271
+ {
272
+ CHAR *start = reinterpret_cast <CHAR *>(data_ptr);
273
+ input.read (start, num_elements * sizeof (T));
274
+ }
275
+ else
276
+ {
277
+ CHAR buffer[sizeof (T)];
278
+ for (auto curr = data_ptr; curr < data_ptr + num_elements; ++curr)
279
+ {
280
+ input.read (buffer, sizeof (T));
281
+ CHAR *start = reinterpret_cast <CHAR *>(curr);
282
+ std::reverse_copy (buffer, buffer + sizeof (T), start);
283
+ }
284
+ }
285
+ }
286
+
205
287
/* * Loads a tensor in NPY format from the provided stream. The type of the tensor
206
288
* must match the data to be read.
207
289
* \tparam T the data type
@@ -212,7 +294,8 @@ header_info read_npy_header(std::basic_istream<CHAR> &input)
212
294
*/
213
295
template <typename T,
214
296
template <typename > class TENSOR ,
215
- typename CHAR>
297
+ typename CHAR,
298
+ std::enable_if_t <!std::is_same<std::wstring, T>::value, int > = 42 >
216
299
TENSOR<T> load (std::basic_istream<CHAR> &input)
217
300
{
218
301
header_info info = read_npy_header (input);
@@ -222,20 +305,45 @@ TENSOR<T> load(std::basic_istream<CHAR> &input)
222
305
throw std::logic_error (" requested dtype does not match stream's dtype" );
223
306
}
224
307
225
- if (info.endianness == npy::endian_t ::NATIVE || info.endianness == native_endian ())
308
+ copy_to (input, tensor.data (), tensor.size (), info.endianness );
309
+ return tensor;
310
+ }
311
+
312
+
313
+ /* * Loads a unicode string tensor in NPY format from the provided stream. The type of the tensor
314
+ * must match the data to be read.
315
+ * \tparam T the data type
316
+ * \tparam TENSOR the tensor type
317
+ * \param input the input stream
318
+ * \return an object of type TENSOR<T> read from the stream
319
+ * \sa npy::tensor
320
+ */
321
+ template <typename T,
322
+ template <typename > class TENSOR ,
323
+ typename CHAR,
324
+ std::enable_if_t <std::is_same<std::wstring, T>::value, int > = 42 >
325
+ TENSOR<T> load (std::basic_istream<CHAR> &input)
326
+ {
327
+ header_info info = read_npy_header (input);
328
+ TENSOR<T> tensor (info.shape , info.fortran_order );
329
+ if (info.dtype != tensor.dtype ())
226
330
{
227
- CHAR *start = reinterpret_cast <CHAR *>(tensor.data ());
228
- input.read (start, tensor.size () * sizeof (T));
331
+ throw std::logic_error (" requested dtype does not match stream's dtype" );
229
332
}
230
- else
333
+
334
+ std::vector<std::int32_t > unicode (tensor.size () * info.max_element_length , 0 );
335
+ copy_to (input, unicode.data (), unicode.size (), info.endianness );
336
+
337
+ auto word_start = unicode.begin ();
338
+ for (auto & element : tensor)
231
339
{
232
- CHAR buffer[ sizeof (T)] ;
233
- for ( auto curr = tensor. data (); curr < tensor. data () + tensor. size () ; ++curr )
340
+ auto char_it = word_start ;
341
+ for (std:: size_t i= 0 ; i<info. max_element_length && *char_it > 0 ; ++i, ++char_it )
234
342
{
235
- input.read (buffer, sizeof (T));
236
- CHAR *start = reinterpret_cast <CHAR *>(curr);
237
- std::reverse_copy (buffer, buffer + sizeof (T), start);
343
+ element.push_back (static_cast <wchar_t >(*char_it));
238
344
}
345
+
346
+ word_start += info.max_element_length ;
239
347
}
240
348
241
349
return tensor;
0 commit comments