from __future__ import annotations

import os
import xml.etree.ElementTree as ET
from abc import abstractmethod
from io import BytesIO
from logging import getLogger

import dill
import luigi
import luigi.contrib.s3
import luigi.format
import numpy as np
import pandas as pd
import pandas.errors
from luigi.format import TextFormat

from gokart.object_storage import ObjectStorage
from gokart.utils import load_dill_with_pandas_backward_compatibility

logger = getLogger(__name__)


class FileProcessor:
    """Interface of processors which serialize an object to a file and deserialize it back."""

    @abstractmethod
    def format(self):
        pass

    @abstractmethod
    def load(self, file):
        pass

    @abstractmethod
    def dump(self, obj, file):
        pass


class BinaryFileProcessor(FileProcessor):
    """
    Pass bytes to this processor.

    ```
    figure_binary = io.BytesIO()
    plt.savefig(figure_binary)
    figure_binary.seek(0)
    with open('figure.png', 'wb') as f:
        BinaryFileProcessor().dump(figure_binary.read(), f)
    ```
    """

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        return file.read()

    def dump(self, obj, file):
        file.write(obj)


class _ChunkedLargeFileReader:
    """Wrap a file object and split huge reads into chunks smaller than 2 GiB."""

    def __init__(self, file) -> None:
        self._file = file

    def __getattr__(self, item):
        return getattr(self._file, item)

    def read(self, n):
        if n >= (1 << 31):
            logger.info(f'reading a large file with total_bytes={n}.')
            buffer = bytearray(n)
            idx = 0
            while idx < n:
                # Read at most 2**31 - 1 bytes at a time.
                batch_size = min(n - idx, (1 << 31) - 1)
                logger.info(f'reading bytes [{idx}, {idx + batch_size})...')
                buffer[idx : idx + batch_size] = self._file.read(batch_size)
                idx += batch_size
            logger.info('done.')
            return buffer
        return self._file.read(n)


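# Illustrative sketch (not part of gokart): _ChunkedLargeFileReader behaves like the
# wrapped file object; only reads of 2 GiB or more are split into chunks. The helper
# name below is hypothetical.
def _example_chunked_reader():
    wrapped = _ChunkedLargeFileReader(BytesIO(b'small payload'))
    # A small read is delegated to the underlying file unchanged.
    return wrapped.read(13)

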
class PickleFileProcessor(FileProcessor):
    def format(self):
        return luigi.format.Nop

    def load(self, file):
        if not file.seekable():
            # load_dill_with_pandas_backward_compatibility() requires a file object with seek() and readlines() implemented.
            # Therefore, we wrap the file with BytesIO, which is seekable and supports readlines().
            # For example, ReadableS3File is not a seekable file.
            return load_dill_with_pandas_backward_compatibility(BytesIO(file.read()))
        return load_dill_with_pandas_backward_compatibility(_ChunkedLargeFileReader(file))

    def dump(self, obj, file):
        self._write(dill.dumps(obj, protocol=4), file)

    @staticmethod
    def _write(buffer, file):
        n = len(buffer)
        idx = 0
        logger.info(f'writing a file with total_bytes={n}...')
        while idx < n:
            # Write at most 2**31 - 1 bytes at a time.
            batch_size = min(n - idx, (1 << 31) - 1)
            logger.info(f'writing bytes [{idx}, {idx + batch_size})')
            file.write(buffer[idx : idx + batch_size])
            idx += batch_size
        logger.info('done')


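# Illustrative sketch (not part of gokart): PickleFileProcessor round-trips arbitrary
# Python objects through dill. An in-memory buffer stands in for the target file here;
# the helper name is hypothetical.
def _example_pickle_round_trip():
    processor = PickleFileProcessor()
    buffer = BytesIO()
    processor.dump({'threshold': 0.5, 'labels': ['a', 'b']}, buffer)
    buffer.seek(0)
    # BytesIO is seekable, so load() reads it through _ChunkedLargeFileReader.
    return processor.load(buffer)

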
class TextFileProcessor(FileProcessor):
    def format(self):
        return None

    def load(self, file):
        return [s.rstrip() for s in file.readlines()]

    def dump(self, obj, file):
        if isinstance(obj, list):
            for x in obj:
                file.write(str(x) + '\n')
        else:
            file.write(str(obj))


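# Illustrative sketch (not part of gokart): TextFileProcessor writes one line per list
# element and loads the file back as a list of stripped lines. The helper name is
# hypothetical; a StringIO stands in for the text-mode target file.
def _example_text_round_trip():
    from io import StringIO

    processor = TextFileProcessor()
    buffer = StringIO()
    processor.dump(['first line', 'second line'], buffer)
    buffer.seek(0)
    return processor.load(buffer)  # ['first line', 'second line']

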
class CsvFileProcessor(FileProcessor):
    def __init__(self, sep=',', encoding: str = 'utf-8'):
        self._sep = sep
        self._encoding = encoding
        super().__init__()

    def format(self):
        return TextFormat(encoding=self._encoding)

    def load(self, file):
        try:
            return pd.read_csv(file, sep=self._sep, encoding=self._encoding)
        except pd.errors.EmptyDataError:
            return pd.DataFrame()

    def dump(self, obj, file):
        assert isinstance(obj, pd.DataFrame | pd.Series), f'requires pd.DataFrame or pd.Series, but {type(obj)} is passed.'
        obj.to_csv(file, mode='wt', index=False, sep=self._sep, header=True, encoding=self._encoding)


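# Illustrative sketch (not part of gokart): CsvFileProcessor dumps a DataFrame as
# delimited text and reads it back into an equivalent frame. The helper name is
# hypothetical; a StringIO stands in for the text-mode target file.
def _example_csv_round_trip():
    from io import StringIO

    processor = CsvFileProcessor(sep='\t')  # '\t' is the separator used for '.tsv' targets
    df = pd.DataFrame({'a': [1, 2], 'b': ['x', 'y']})
    buffer = StringIO()
    processor.dump(df, buffer)
    buffer.seek(0)
    return processor.load(buffer)

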
class GzipFileProcessor(FileProcessor):
    def format(self):
        return luigi.format.Gzip

    def load(self, file):
        return [s.rstrip().decode() for s in file.readlines()]

    def dump(self, obj, file):
        if isinstance(obj, list):
            for x in obj:
                file.write((str(x) + '\n').encode())
        else:
            file.write(str(obj).encode())


class JsonFileProcessor(FileProcessor):
    def __init__(self, orient: str | None = None):
        self._orient = orient

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        try:
            return pd.read_json(file, orient=self._orient, lines=self._orient == 'records')
        except pd.errors.EmptyDataError:
            return pd.DataFrame()

    def dump(self, obj, file):
        assert isinstance(obj, (pd.DataFrame, pd.Series, dict)), f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.'
        if isinstance(obj, dict):
            obj = pd.DataFrame.from_dict(obj)
        obj.to_json(file, orient=self._orient, lines=self._orient == 'records')


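# Illustrative sketch (not part of gokart): with orient='records' the processor writes
# newline-delimited JSON (one record per line), which is how '.ndjson' targets are
# handled. The helper name is hypothetical; a BytesIO stands in for the target file.
def _example_ndjson_round_trip():
    processor = JsonFileProcessor(orient='records')
    df = pd.DataFrame({'user': ['a', 'b'], 'score': [1, 2]})
    buffer = BytesIO()
    processor.dump(df, buffer)
    buffer.seek(0)
    return processor.load(buffer)

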
class XmlFileProcessor(FileProcessor):
    def format(self):
        return None

    def load(self, file):
        try:
            return ET.parse(file)
        except ET.ParseError:
            return ET.ElementTree()

    def dump(self, obj, file):
        assert isinstance(obj, ET.ElementTree), f'requires ET.ElementTree, but {type(obj)} is passed.'
        obj.write(file)


class NpzFileProcessor(FileProcessor):
    def format(self):
        return luigi.format.Nop

    def load(self, file):
        return np.load(file)['data']

    def dump(self, obj, file):
        assert isinstance(obj, np.ndarray), f'requires np.ndarray, but {type(obj)} is passed.'
        np.savez_compressed(file, data=obj)


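# Illustrative sketch (not part of gokart): NpzFileProcessor stores a single array under
# the fixed key 'data' in a compressed .npz archive. The helper name is hypothetical;
# a BytesIO stands in for the target file.
def _example_npz_round_trip():
    processor = NpzFileProcessor()
    buffer = BytesIO()
    processor.dump(np.arange(6).reshape(2, 3), buffer)
    buffer.seek(0)
    return processor.load(buffer)  # array([[0, 1, 2], [3, 4, 5]])

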
class ParquetFileProcessor(FileProcessor):
    def __init__(self, engine='pyarrow', compression=None):
        self._engine = engine
        self._compression = compression
        super().__init__()

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        # FIXME(mamo3gr): enable streaming (chunked) read with S3.
        # pandas.read_parquet accepts a file-like object, but the file
        # (luigi.contrib.s3.ReadableS3File) would need a 'tell' method,
        # which pandas requires for reading a file in chunks.
        if ObjectStorage.is_buffered_reader(file):
            return pd.read_parquet(file.name)
        else:
            return pd.read_parquet(BytesIO(file.read()))

    def dump(self, obj, file):
        assert isinstance(obj, pd.DataFrame), f'requires pd.DataFrame, but {type(obj)} is passed.'
        # MEMO: to_parquet only supports a file path as string (not a file handle)
        obj.to_parquet(file.name, index=False, engine=self._engine, compression=self._compression)


class FeatherFileProcessor(FileProcessor):
    def __init__(self, store_index_in_feather: bool):
        super().__init__()
        self._store_index_in_feather = store_index_in_feather
        self.INDEX_COLUMN_PREFIX = '__feather_gokart_index__'

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        # FIXME(mamo3gr): enable streaming (chunked) read with S3.
        # pandas.read_feather accepts a file-like object, but the file
        # (luigi.contrib.s3.ReadableS3File) would need a 'tell' method,
        # which pandas requires for reading a file in chunks.
        if ObjectStorage.is_buffered_reader(file):
            loaded_df = pd.read_feather(file.name)
        else:
            loaded_df = pd.read_feather(BytesIO(file.read()))

        if self._store_index_in_feather:
            if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns):
                index_columns = [col_name for col_name in loaded_df.columns[::-1] if col_name.startswith(self.INDEX_COLUMN_PREFIX)]
                index_column = index_columns[0]
                index_name = index_column[len(self.INDEX_COLUMN_PREFIX) :]
                if index_name == 'None':
                    index_name = None
                loaded_df.index = pd.Index(loaded_df[index_column].values, name=index_name)
                loaded_df = loaded_df.drop(columns=[index_column])

        return loaded_df

    def dump(self, obj, file):
        assert isinstance(obj, pd.DataFrame), f'requires pd.DataFrame, but {type(obj)} is passed.'
        dump_obj = obj.copy()

        if self._store_index_in_feather:
            index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}'
            assert index_column_name not in dump_obj.columns, (
                f'column name {index_column_name} already exists in dump_obj. Consider not saving the index by setting store_index_in_feather=False.'
            )
            assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.'

            dump_obj[index_column_name] = dump_obj.index
            dump_obj = dump_obj.reset_index(drop=True)

        # to_feather supports a binary file-like object, but the file variable here is text, so pass the file path instead.
        dump_obj.to_feather(file.name)


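# Illustrative sketch (not part of gokart): with store_index_in_feather=True, the index
# is saved as an extra '__feather_gokart_index__<name>' column and restored on load, so a
# named index survives the round trip. The helper name and path are hypothetical, and
# pyarrow is assumed to be available for feather I/O.
def _example_feather_index_round_trip():
    import tempfile

    processor = FeatherFileProcessor(store_index_in_feather=True)
    df = pd.DataFrame({'value': [10, 20]}, index=pd.Index(['a', 'b'], name='key'))
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, 'example.feather')
        with open(path, 'wb') as f:
            processor.dump(df, f)  # to_feather writes via f.name
        with open(path, 'rb') as f:
            return processor.load(f)  # index name 'key' and labels ['a', 'b'] are restored

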
def make_file_processor(file_path: str, store_index_in_feather: bool) -> FileProcessor:
    extension2processor = {
        '.txt': TextFileProcessor(),
        '.ini': TextFileProcessor(),
        '.csv': CsvFileProcessor(sep=','),
        '.tsv': CsvFileProcessor(sep='\t'),
        '.pkl': PickleFileProcessor(),
        '.gz': GzipFileProcessor(),
        '.json': JsonFileProcessor(),
        '.ndjson': JsonFileProcessor(orient='records'),
        '.xml': XmlFileProcessor(),
        '.npz': NpzFileProcessor(),
        '.parquet': ParquetFileProcessor(compression='gzip'),
        '.feather': FeatherFileProcessor(store_index_in_feather=store_index_in_feather),
        '.png': BinaryFileProcessor(),
        '.jpg': BinaryFileProcessor(),
    }

    extension = os.path.splitext(file_path)[1]
    assert extension in extension2processor, f'{extension} is not supported. The supported extensions are {list(extension2processor.keys())}.'
    return extension2processor[extension]
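

# Illustrative sketch (not part of gokart): the processor is chosen purely by file
# extension. The helper name and path are hypothetical.
def _example_processor_selection():
    processor = make_file_processor('output/scores.tsv', store_index_in_feather=False)
    # A '.tsv' target gets a CsvFileProcessor configured with a tab separator.
    return isinstance(processor, CsvFileProcessor)  # True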