Skip to content

Commit e879df6

Browse files
committed
add batch process
1 parent cb3f2f4 commit e879df6

File tree

2 files changed

+76
-23
lines changed

2 files changed

+76
-23
lines changed

snippets/utils.py

+67-19
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
import os
1616
import pickle
1717
import re
18+
import shutil
1819
import subprocess
1920
import time
2021
from datetime import datetime
2122
import pandas as pd
22-
from typing import Any, Dict, Iterable, List, Sequence, Tuple, _GenericAlias, Union
23+
from typing import Any, Callable, Dict, Iterable, List, Sequence, Tuple, _GenericAlias, Union
2324

2425
import numpy as np
2526
from pydantic import BaseModel
@@ -75,7 +76,7 @@ def dump_lines(lines: List[str], fp):
7576
if isinstance(fp, str):
7677
fp = open(fp, mode="w", encoding="utf8")
7778
with fp:
78-
lines = [e + "\n" for e in lines]
79+
lines = [str(e) + "\n" for e in lines]
7980
fp.writelines(lines)
8081

8182

@@ -163,7 +164,7 @@ def table2json(path, **kwargs):
163164

164165

165166
# 将list数据存储成table格式
166-
def dump2table(data, path:str):
167+
def dump2table(data, path: str):
167168
if isinstance(data, list):
168169
data = pd.DataFrame.from_records(data)
169170
assert isinstance(data, pd.DataFrame)
@@ -215,13 +216,12 @@ def _read2list(file_path, **kwargs):
215216
rs.extend(tmp)
216217
else:
217218
rs.append(tmp)
218-
if len(file_path) == 1 and len(rs)==1 and file_path[0].endswith(".json"):
219-
rs = rs[0]
219+
if len(file_path) == 1 and len(rs) == 1 and file_path[0].endswith(".json"):
220+
rs = rs[0]
220221

221222
return rs
222223

223224

224-
225225
# 将list数据按照后缀名格式dump到文件
226226
def dump2list(data: List, file_path: str, **kwargs):
227227
create_dir_path(file_path)
@@ -259,7 +259,7 @@ def pretty_floats(obj, r=4):
259259

260260

261261
# 将data batch化输出
262-
def get_batched_data(data: Iterable, batch_size: int)->Iterable[List]:
262+
def get_batched_data(data: Iterable, batch_size: int) -> Iterable[List]:
263263
"""将数据按照batch_size分组
264264
265265
Args:
@@ -349,6 +349,8 @@ def print_info(info, target_logger=None, fix_length=128):
349349
print(star_info)
350350

351351
# 把一个dict转化到一个Union类型
352+
353+
352354
def union_parse_obj(union: _GenericAlias, d: dict):
353355
for cls in union.__args__:
354356
try:
@@ -359,6 +361,8 @@ def union_parse_obj(union: _GenericAlias, d: dict):
359361
raise Exception(f"fail to convert {d} to union {union}")
360362

361363
# 获取一个包的最新version
364+
365+
362366
def get_latest_version(package_name: str) -> str:
363367
cmd = f"pip install {package_name}=="
364368
status, output = execute_cmd(cmd)
@@ -377,6 +381,8 @@ def get_latest_version(package_name: str) -> str:
377381
return latest_version
378382

379383
# 获取一个version的下一个版本
384+
385+
380386
def get_next_version(version: str, level=0) -> str:
381387
pieces = version.split(".")
382388
idx = len(pieces) - level - 1
@@ -388,8 +394,7 @@ def get_next_version(version: str, level=0) -> str:
388394
return ".".join(pieces)
389395

390396

391-
392-
def deep_update(origin:dict, new_data:dict, inplace=True)->dict:
397+
def deep_update(origin: dict, new_data: dict, inplace=True) -> dict:
393398
"""递归跟新dict
394399
395400
Args:
@@ -401,20 +406,63 @@ def deep_update(origin:dict, new_data:dict, inplace=True)->dict:
401406
dict: 更新后的dict
402407
"""
403408
to_update = origin if inplace else copy.deepcopy(origin)
404-
405-
def _deep_update_inplace(tu:dict, nd:dict)->dict:
406-
for k,v in nd.items():
409+
410+
def _deep_update_inplace(tu: dict, nd: dict) -> dict:
411+
for k, v in nd.items():
407412
if k not in to_update:
408413
tu[k] = v
409414
else:
410415
if isinstance(v, dict) and isinstance(to_update[k], dict):
411416
_deep_update_inplace(to_update[k], v)
412417
else:
413-
tu[k] =v
414-
return tu
418+
tu[k] = v
419+
return tu
415420
return _deep_update_inplace(to_update, new_data)
416-
417-
418-
419-
420-
421+
422+
423+
def delete_paths(paths: str | List[str]):
424+
if isinstance(paths, str):
425+
paths = [paths]
426+
for path in paths:
427+
try:
428+
if os.path.isfile(path): # Check if it's a file
429+
os.remove(path)
430+
# print(f"File {path} has been deleted.")
431+
elif os.path.isdir(path): # Check if it's a directory
432+
shutil.rmtree(path)
433+
# print(f"Directory {path} has been deleted.")
434+
else:
435+
pass
436+
# print(f"No such file or directory: {path}")
437+
except Exception as e:
438+
raise e
439+
# logger.exception("e")
440+
441+
442+
def batch_process_with_save(data: Iterable, func: Callable, file_path: str, batch_size: int, **kwargs):
443+
"""将数据按照指定函数dump到文件
444+
445+
Args:
446+
data (Iterable): 待dump数据
447+
func (Callable): dump函数
448+
file_path (str): 文件路径
449+
**kwargs: dump函数参数
450+
"""
451+
stem, suffix = os.path.splitext(file_path)
452+
history_files = glob.glob(f"{stem}-[0-9]*-[0-9]*.{suffix}")
453+
acc = load(history_files)
454+
455+
logger.info(f"{len(acc)} history data loaded")
456+
# logger.info(history_files)
457+
data = data[len(acc):]
458+
459+
for idx, batch in enumerate(batchify(data, batch_size)):
460+
batch_result = list(func(batch, **kwargs))
461+
dump(batch_result, f"{stem}-{idx*batch_size}-{idx*batch_size+len(batch)}.{suffix}")
462+
acc.extend(batch_result)
463+
464+
history_files = glob.glob(f"{stem}*.{suffix}")
465+
delete_paths(history_files)
466+
467+
dump(acc, file_path)
468+
# logger.info(history_files)

tests/test_utils.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_get_next_version(self):
5252
def test_get_latest_version(self):
5353
latest_version = get_latest_version("python-snippets")
5454
print(latest_version)
55-
55+
5656
def test_deep_update(self):
5757
origin = dict(a=1, b=dict(c=1), c="c")
5858
to_update = dict(a=2, b=dict(e=2), c=dict(f="f"), k="k")
@@ -61,13 +61,18 @@ def test_deep_update(self):
6161
print(origin)
6262
self.assertEquals(updated, {'a': 2, 'b': {'c': 1, 'e': 2}, 'c': {'f': 'f'}, 'k': 'k'})
6363
self.assertEquals(origin, {'a': 1, 'b': {'c': 1}, 'c': 'c'})
64-
6564

6665
def test_load(self):
6766
data = load("data/sample.*")
6867
print(len(data))
6968
self.assertEquals(len(data), 6)
7069
data = load("/Users/chenhao/workspace/XAgents/knowledge_base/test_service/config.json")
7170
print(data)
72-
73-
71+
72+
def test_batch_process_with_save(self):
73+
data = range(20)
74+
def func(x): return (e**2 for e in x)
75+
dist_path = "./batch_process_result.txt"
76+
batch_process_with_save(data, func, dist_path, batch_size=6)
77+
if os.path.exists(dist_path):
78+
os.remove(dist_path)

0 commit comments

Comments
 (0)