
Commit c39e67a

refactor: modularize file processors with DataFrame backend support
1 parent a4607b8 commit c39e67a
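
The new per-format modules are not part of this file's diff, so as a rough illustration of the "DataFrame backend support" mentioned in the commit title, here is a minimal, hypothetical sketch of a backend-aware CSV loader; the class name CsvLoader and the backend parameter are illustrative assumptions, not the actual API added by this commit.

```
import pandas as pd
import polars as pl


class CsvLoader:
    """Hypothetical sketch: dispatch CSV reading to a configurable DataFrame backend."""

    def __init__(self, backend: str = 'pandas', sep: str = ','):
        self._backend = backend  # 'pandas' or 'polars' (illustrative only)
        self._sep = sep

    def load(self, file):
        if self._backend == 'polars':
            return pl.read_csv(file, separator=self._sep)  # polars backend
        return pd.read_csv(file, sep=self._sep)  # pandas backend (default)
```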

File tree

8 files changed: +638 −424 lines


gokart/file_processor.py

Lines changed: 0 additions & 304 deletions
@@ -1,304 +0,0 @@
from __future__ import annotations

import os
import xml.etree.ElementTree as ET
from abc import abstractmethod
from io import BytesIO
from logging import getLogger

import dill
import luigi
import luigi.contrib.s3
import luigi.format
import numpy as np
import pandas as pd
import pandas.errors
from luigi.format import TextFormat

from gokart.object_storage import ObjectStorage
from gokart.utils import load_dill_with_pandas_backward_compatibility

logger = getLogger(__name__)


class FileProcessor:
    @abstractmethod
    def format(self):
        pass

    @abstractmethod
    def load(self, file):
        pass

    @abstractmethod
    def dump(self, obj, file):
        pass


class BinaryFileProcessor(FileProcessor):
    """
    Pass bytes to this processor

    ```
    figure_binary = io.BytesIO()
    plt.savefig(figure_binary)
    figure_binary.seek(0)
    BinaryFileProcessor().dump(figure_binary.read())
    ```
    """

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        return file.read()

    def dump(self, obj, file):
        file.write(obj)


class _ChunkedLargeFileReader:
    def __init__(self, file) -> None:
        self._file = file

    def __getattr__(self, item):
        return getattr(self._file, item)

    def read(self, n):
        if n >= (1 << 31):
            logger.info(f'reading a large file with total_bytes={n}.')
            buffer = bytearray(n)
            idx = 0
            while idx < n:
                batch_size = min(n - idx, 1 << 31 - 1)
                logger.info(f'reading bytes [{idx}, {idx + batch_size})...')
                buffer[idx : idx + batch_size] = self._file.read(batch_size)
                idx += batch_size
            logger.info('done.')
            return buffer
        return self._file.read(n)


class PickleFileProcessor(FileProcessor):
    def format(self):
        return luigi.format.Nop

    def load(self, file):
        if not file.seekable():
            # load_dill_with_pandas_backward_compatibility() requires file with seek() and readlines() implemented.
            # Therefore, we need to wrap with BytesIO which makes file seekable and readlinesable.
            # For example, ReadableS3File is not a seekable file.
            return load_dill_with_pandas_backward_compatibility(BytesIO(file.read()))
        return load_dill_with_pandas_backward_compatibility(_ChunkedLargeFileReader(file))

    def dump(self, obj, file):
        self._write(dill.dumps(obj, protocol=4), file)

    @staticmethod
    def _write(buffer, file):
        n = len(buffer)
        idx = 0
        while idx < n:
            logger.info(f'writing a file with total_bytes={n}...')
            batch_size = min(n - idx, 1 << 31 - 1)
            logger.info(f'writing bytes [{idx}, {idx + batch_size})')
            file.write(buffer[idx : idx + batch_size])
            idx += batch_size
        logger.info('done')


class TextFileProcessor(FileProcessor):
    def format(self):
        return None

    def load(self, file):
        return [s.rstrip() for s in file.readlines()]

    def dump(self, obj, file):
        if isinstance(obj, list):
            for x in obj:
                file.write(str(x) + '\n')
        else:
            file.write(str(obj))


class CsvFileProcessor(FileProcessor):
    def __init__(self, sep=',', encoding: str = 'utf-8'):
        self._sep = sep
        self._encoding = encoding
        super().__init__()

    def format(self):
        return TextFormat(encoding=self._encoding)

    def load(self, file):
        try:
            return pd.read_csv(file, sep=self._sep, encoding=self._encoding)
        except pd.errors.EmptyDataError:
            return pd.DataFrame()

    def dump(self, obj, file):
        assert isinstance(obj, pd.DataFrame | pd.Series), f'requires pd.DataFrame or pd.Series, but {type(obj)} is passed.'
        obj.to_csv(file, mode='wt', index=False, sep=self._sep, header=True, encoding=self._encoding)


class GzipFileProcessor(FileProcessor):
    def format(self):
        return luigi.format.Gzip

    def load(self, file):
        return [s.rstrip().decode() for s in file.readlines()]

    def dump(self, obj, file):
        if isinstance(obj, list):
            for x in obj:
                file.write((str(x) + '\n').encode())
        else:
            file.write(str(obj).encode())


class JsonFileProcessor(FileProcessor):
    def __init__(self, orient: str | None = None):
        self._orient = orient

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        try:
            return pd.read_json(file, orient=self._orient, lines=True if self._orient == 'records' else False)
        except pd.errors.EmptyDataError:
            return pd.DataFrame()

    def dump(self, obj, file):
        assert isinstance(obj, pd.DataFrame) or isinstance(obj, pd.Series) or isinstance(obj, dict), (
            f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.'
        )
        if isinstance(obj, dict):
            obj = pd.DataFrame.from_dict(obj)
        obj.to_json(file, orient=self._orient, lines=True if self._orient == 'records' else False)


class XmlFileProcessor(FileProcessor):
    def format(self):
        return None

    def load(self, file):
        try:
            return ET.parse(file)
        except ET.ParseError:
            return ET.ElementTree()

    def dump(self, obj, file):
        assert isinstance(obj, ET.ElementTree), f'requires ET.ElementTree, but {type(obj)} is passed.'
        obj.write(file)


class NpzFileProcessor(FileProcessor):
    def format(self):
        return luigi.format.Nop

    def load(self, file):
        return np.load(file)['data']

    def dump(self, obj, file):
        assert isinstance(obj, np.ndarray), f'requires np.ndarray, but {type(obj)} is passed.'
        np.savez_compressed(file, data=obj)


class ParquetFileProcessor(FileProcessor):
    def __init__(self, engine='pyarrow', compression=None):
        self._engine = engine
        self._compression = compression
        super().__init__()

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        # FIXME(mamo3gr): enable streaming (chunked) read with S3.
        # pandas.read_parquet accepts file-like object
        # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method,
        # which is needed for pandas to read a file in chunks.
        if ObjectStorage.is_buffered_reader(file):
            return pd.read_parquet(file.name)
        else:
            return pd.read_parquet(BytesIO(file.read()))

    def dump(self, obj, file):
        assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.'
        # MEMO: to_parquet only supports a filepath as string (not a file handle)
        obj.to_parquet(file.name, index=False, engine=self._engine, compression=self._compression)


class FeatherFileProcessor(FileProcessor):
    def __init__(self, store_index_in_feather: bool):
        super().__init__()
        self._store_index_in_feather = store_index_in_feather
        self.INDEX_COLUMN_PREFIX = '__feather_gokart_index__'

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        # FIXME(mamo3gr): enable streaming (chunked) read with S3.
        # pandas.read_feather accepts file-like object
        # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method,
        # which is needed for pandas to read a file in chunks.
        if ObjectStorage.is_buffered_reader(file):
            loaded_df = pd.read_feather(file.name)
        else:
            loaded_df = pd.read_feather(BytesIO(file.read()))

        if self._store_index_in_feather:
            if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns):
                index_columns = [col_name for col_name in loaded_df.columns[::-1] if col_name[: len(self.INDEX_COLUMN_PREFIX)] == self.INDEX_COLUMN_PREFIX]
                index_column = index_columns[0]
                index_name = index_column[len(self.INDEX_COLUMN_PREFIX) :]
                if index_name == 'None':
                    index_name = None
                loaded_df.index = pd.Index(loaded_df[index_column].values, name=index_name)
                loaded_df = loaded_df.drop(columns={index_column})

        return loaded_df

    def dump(self, obj, file):
        assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.'
        dump_obj = obj.copy()

        if self._store_index_in_feather:
            index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}'
            assert index_column_name not in dump_obj.columns, (
                f'column name {index_column_name} already exists in dump_obj. \
                Consider not saving index by setting store_index_in_feather=False.'
            )
            assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.'

            dump_obj[index_column_name] = dump_obj.index
            dump_obj = dump_obj.reset_index(drop=True)

        # to_feather supports "binary" file-like object, but file variable is text
        dump_obj.to_feather(file.name)


def make_file_processor(file_path: str, store_index_in_feather: bool) -> FileProcessor:
    extension2processor = {
        '.txt': TextFileProcessor(),
        '.ini': TextFileProcessor(),
        '.csv': CsvFileProcessor(sep=','),
        '.tsv': CsvFileProcessor(sep='\t'),
        '.pkl': PickleFileProcessor(),
        '.gz': GzipFileProcessor(),
        '.json': JsonFileProcessor(),
        '.ndjson': JsonFileProcessor(orient='records'),
        '.xml': XmlFileProcessor(),
        '.npz': NpzFileProcessor(),
        '.parquet': ParquetFileProcessor(compression='gzip'),
        '.feather': FeatherFileProcessor(store_index_in_feather=store_index_in_feather),
        '.png': BinaryFileProcessor(),
        '.jpg': BinaryFileProcessor(),
    }

    extension = os.path.splitext(file_path)[1]
    assert extension in extension2processor, f'{extension} is not supported. The supported extensions are {list(extension2processor.keys())}.'
    return extension2processor[extension]
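
For reference, here is a short usage sketch against the pre-refactor module shown above (file names are hypothetical): make_file_processor selects a processor by file extension, and the processor's dump/load round-trips a pandas DataFrame.

```
import pandas as pd

from gokart.file_processor import CsvFileProcessor, make_file_processor

# '.csv' maps to CsvFileProcessor in the extension table above.
processor = make_file_processor('out.csv', store_index_in_feather=False)
assert isinstance(processor, CsvFileProcessor)

df = pd.DataFrame({'a': [1, 2], 'b': ['x', 'y']})
with open('out.csv', 'w', encoding='utf-8') as f:
    processor.dump(df, f)  # writes CSV without the index
with open('out.csv', 'r', encoding='utf-8') as f:
    loaded = processor.load(f)  # returns an empty DataFrame if the file has no data
```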

0 commit comments