15
15
import os
16
16
import pickle
17
17
import re
18
+ import shutil
18
19
import subprocess
19
20
import time
20
21
from datetime import datetime
21
22
import pandas as pd
22
- from typing import Any , Dict , Iterable , List , Sequence , Tuple , _GenericAlias , Union
23
+ from typing import Any , Callable , Dict , Iterable , List , Sequence , Tuple , _GenericAlias , Union
23
24
24
25
import numpy as np
25
26
from pydantic import BaseModel
@@ -75,7 +76,7 @@ def dump_lines(lines: List[str], fp):
75
76
if isinstance (fp , str ):
76
77
fp = open (fp , mode = "w" , encoding = "utf8" )
77
78
with fp :
78
- lines = [e + "\n " for e in lines ]
79
+ lines = [str ( e ) + "\n " for e in lines ]
79
80
fp .writelines (lines )
80
81
81
82
@@ -163,7 +164,7 @@ def table2json(path, **kwargs):
163
164
164
165
165
166
# 将list数据存储成table格式
166
- def dump2table (data , path :str ):
167
+ def dump2table (data , path : str ):
167
168
if isinstance (data , list ):
168
169
data = pd .DataFrame .from_records (data )
169
170
assert isinstance (data , pd .DataFrame )
@@ -215,13 +216,12 @@ def _read2list(file_path, **kwargs):
215
216
rs .extend (tmp )
216
217
else :
217
218
rs .append (tmp )
218
- if len (file_path ) == 1 and len (rs )== 1 and file_path [0 ].endswith (".json" ):
219
- rs = rs [0 ]
219
+ if len (file_path ) == 1 and len (rs ) == 1 and file_path [0 ].endswith (".json" ):
220
+ rs = rs [0 ]
220
221
221
222
return rs
222
223
223
224
224
-
225
225
# 将list数据按照后缀名格式dump到文件
226
226
def dump2list (data : List , file_path : str , ** kwargs ):
227
227
create_dir_path (file_path )
@@ -259,7 +259,7 @@ def pretty_floats(obj, r=4):
259
259
260
260
261
261
# 将data batch化输出
262
- def get_batched_data (data : Iterable , batch_size : int )-> Iterable [List ]:
262
+ def get_batched_data (data : Iterable , batch_size : int ) -> Iterable [List ]:
263
263
"""将数据按照batch_size分组
264
264
265
265
Args:
@@ -349,6 +349,8 @@ def print_info(info, target_logger=None, fix_length=128):
349
349
print (star_info )
350
350
351
351
# 把一个dict转化到一个Union类型
352
+
353
+
352
354
def union_parse_obj (union : _GenericAlias , d : dict ):
353
355
for cls in union .__args__ :
354
356
try :
@@ -359,6 +361,8 @@ def union_parse_obj(union: _GenericAlias, d: dict):
359
361
raise Exception (f"fail to convert { d } to union { union } " )
360
362
361
363
# 获取一个包的最新version
364
+
365
+
362
366
def get_latest_version (package_name : str ) -> str :
363
367
cmd = f"pip install { package_name } =="
364
368
status , output = execute_cmd (cmd )
@@ -377,6 +381,8 @@ def get_latest_version(package_name: str) -> str:
377
381
return latest_version
378
382
379
383
# 获取一个version的下一个版本
384
+
385
+
380
386
def get_next_version (version : str , level = 0 ) -> str :
381
387
pieces = version .split ("." )
382
388
idx = len (pieces ) - level - 1
@@ -388,8 +394,7 @@ def get_next_version(version: str, level=0) -> str:
388
394
return "." .join (pieces )
389
395
390
396
391
-
392
- def deep_update (origin :dict , new_data :dict , inplace = True )-> dict :
397
def deep_update(origin: dict, new_data: dict, inplace=True) -> dict:
    """Recursively merge *new_data* into *origin*.

    Scalar values in *new_data* overwrite the corresponding keys in
    *origin*; when both sides hold a dict for the same key, the merge
    recurses instead of replacing the whole sub-dict.

    Args:
        origin (dict): base dict to update.
        new_data (dict): dict whose entries extend/override *origin*.
        inplace (bool): when False, work on a deep copy and leave
            *origin* untouched.

    Returns:
        dict: the updated dict (*origin* itself when ``inplace=True``).
    """
    to_update = origin if inplace else copy.deepcopy(origin)

    def _deep_update_inplace(tu: dict, nd: dict) -> dict:
        for k, v in nd.items():
            # BUG FIX: membership and nested-dict checks must inspect the
            # *current* level (tu), not the top-level dict (to_update);
            # the old code merged nested updates into the wrong subtree
            # whenever a key existed at both the top and a nested level.
            if k not in tu:
                tu[k] = v
            elif isinstance(v, dict) and isinstance(tu[k], dict):
                _deep_update_inplace(tu[k], v)
            else:
                tu[k] = v
        return tu

    return _deep_update_inplace(to_update, new_data)
416
-
417
-
418
-
419
-
420
-
421
+
422
+
423
+ def delete_paths (paths : str | List [str ]):
424
+ if isinstance (paths , str ):
425
+ paths = [paths ]
426
+ for path in paths :
427
+ try :
428
+ if os .path .isfile (path ): # Check if it's a file
429
+ os .remove (path )
430
+ # print(f"File {path} has been deleted.")
431
+ elif os .path .isdir (path ): # Check if it's a directory
432
+ shutil .rmtree (path )
433
+ # print(f"Directory {path} has been deleted.")
434
+ else :
435
+ pass
436
+ # print(f"No such file or directory: {path}")
437
+ except Exception as e :
438
+ raise e
439
+ # logger.exception("e")
440
+
441
+
442
def batch_process_with_save(data: Iterable, func: Callable, file_path: str, batch_size: int, **kwargs):
    """Process *data* batch by batch with *func*, checkpointing to disk.

    Each processed batch is dumped to a chunk file named
    ``{stem}-<start>-<end>.{suffix}``. On startup any existing chunk files
    are reloaded so an interrupted run resumes where it stopped; after all
    batches finish, the chunk files are deleted and the combined result is
    dumped to *file_path*.

    Args:
        data (Iterable): input data; must support slicing (``data[n:]``)
            for the resume step — assumes a sequence, TODO confirm callers.
        func (Callable): called as ``func(batch, **kwargs)``; its result is
            materialized with ``list()``.
        file_path (str): final output path; its stem/suffix also name the
            intermediate chunk files.
        batch_size (int): number of items per batch.
        **kwargs: forwarded to *func*.
    """
    stem, suffix = os.path.splitext(file_path)
    # NOTE: ``suffix`` keeps its leading dot, so ``.{suffix}`` yields the
    # historical double-dot chunk names (``stem-0-32..json``). Kept as-is
    # deliberately: changing it would orphan chunks from earlier runs.
    history_files = glob.glob(f"{stem}-[0-9]*-[0-9]*.{suffix}")
    # BUG FIX: glob returns files in arbitrary order — sort chunks by their
    # start index so the resumed accumulator preserves the data order.
    history_files.sort(key=lambda p: int(re.findall(r"-(\d+)-\d+\.", p)[-1]))
    acc = load(history_files)

    logger.info(f"{len(acc)} history data loaded")
    done = len(acc)
    data = data[done:]

    for idx, batch in enumerate(batchify(data, batch_size)):
        batch_result = list(func(batch, **kwargs))
        # BUG FIX: offset chunk names by the already-processed item count so
        # a resumed run does not overwrite the existing history chunks
        # (previously idx restarted at 0 and reused the 0..N file names).
        start = done + idx * batch_size
        dump(batch_result, f"{stem}-{start}-{start + len(batch)}.{suffix}")
        acc.extend(batch_result)

    history_files = glob.glob(f"{stem}*.{suffix}")
    delete_paths(history_files)

    dump(acc, file_path)
0 commit comments