Skip to content

Commit c8c45a1

Browse files
committed
添加文件读写函数
1 parent 5e3571d commit c8c45a1

File tree

4 files changed

+64
-6
lines changed

4 files changed

+64
-6
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,4 @@ upload_package.sh
134134

135135
test.ipynb
136136
tests/agi-tools.code-workspace
137+
.DS_Store

setup.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from setuptools import find_packages, setup
1616

17-
from snippets.utils import get_latest_version, get_next_version
17+
from snippets.utils import get_latest_version, get_next_version, read2list
1818

1919
REQ = [
2020
"tqdm",
@@ -23,6 +23,11 @@
2323
"click"
2424
]
2525

26+
def get_install_req():
27+
req = read2list("requirements.txt")
28+
return req
29+
30+
2631

2732
if __name__ == "__main__":
2833
name = "python-snippets"
@@ -32,7 +37,8 @@
3237
latest_version = get_latest_version(name)
3338
version = get_next_version(latest_version)
3439
print(f"version: {version}")
35-
40+
install_req = get_install_req()
41+
print(f"install_req: {install_req}")
3642
setup(
3743
name=name,
3844
version=version,

snippets/mixin.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#!/usr/bin/env python
2+
# -*- encoding: utf-8 -*-
3+
'''
4+
@Time : 2023/10/16 14:09:07
5+
@Author : ChenHao
6+
@Contact : jerrychen1990@gmail.com
7+
'''
8+
9+
from utils import jload
10+
from typing import Union
11+
12+
13+
class ConfigMixin:
14+
@classmethod
15+
def from_config(cls, config: Union[dict, str]):
16+
if isinstance(config, str):
17+
if config.endswith(".json"):
18+
config = jload(config)
19+
else:
20+
raise ValueError(f"{config} is not a valid config file")
21+
instance = cls(**config)
22+
return instance

snippets/utils.py

+33-4
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,6 @@ def get_gen(f):
152152
return list(gen)
153153

154154
# table类的文件转化为list of dict
155-
156-
157155
def table2json(path):
158156
if path.endswith("csv"):
159157
df = pd.read_csv(path)
@@ -163,9 +161,24 @@ def table2json(path):
163161
records = df.to_dict(orient="records")
164162
return records
165163

166-
# 一行一行地读取文件内容
164+
165+
# 将list数据存储成table格式
166+
def dump2table(data, path):
167+
if isinstance(data, list):
168+
data = pd.DataFrame.from_records(data)
169+
assert isinstance(data, pd.DataFrame)
170+
df = data
171+
172+
if path.endswith(".csv"):
173+
df.to_csv(path, index=False)
174+
elif path.endswith(".xlsx"):
175+
df.to_excel(path, index=False)
176+
else:
177+
raise Exception(f"Unknown file format: {path}")
178+
167179

168180

181+
# 一行一行地读取文件内容
169182
def load_lines(fp, return_generator=False):
170183
if isinstance(fp, str):
171184
fp = open(fp, mode="r", encoding="utf8")
@@ -175,7 +188,7 @@ def load_lines(fp, return_generator=False):
175188
return (e.strip() for e in lines if e)
176189
return [e.strip() for e in lines if e]
177190

178-
191+
# 根据后缀名读取list数据
179192
def read2list(file_path: str, **kwargs) -> List[Union[str, dict]]:
180193
name, surfix = split_surfix(file_path)
181194
if surfix == ".json":
@@ -189,6 +202,22 @@ def read2list(file_path: str, **kwargs) -> List[Union[str, dict]]:
189202
else:
190203
logger.warn(f"unkown surfix:{surfix}, read as txt")
191204
return load_lines(file_path, **kwargs)
205+
206+
207+
# 将list数据按照后缀名格式dump到文件
208+
def dump_list(data:List, file_path: str, **kwargs):
209+
name, surfix = split_surfix(file_path)
210+
if surfix == ".json":
211+
return jdump(data, file_path, **kwargs)
212+
if surfix == ".jsonl":
213+
return jdump_lines(data, file_path, **kwargs)
214+
if surfix in ["xlsx", "csv"]:
215+
return dump2table(data, file_path)
216+
if surfix in ["txt"]:
217+
return dump_lines(file_path, **kwargs)
218+
else:
219+
logger.warn(f"unkown surfix:{surfix}, dump as txt")
220+
return load_lines(file_path, **kwargs)
192221

193222

194223
# 递归将obj中的float做精度截断

0 commit comments

Comments
 (0)