forked from chinese-poetry/chinese-poetry
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_loader.py
72 lines (61 loc) · 2.29 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import json
import os
# Default location of the JSON file describing the available datasets.
DATAS_CONFIG = "./loader/datas.json"


class PlainDataLoader():
    """Load the text bodies of datasets described by a JSON config file.

    The config file is read once at construction time. The code relies on
    the following layout:

    * ``"cp_path"``: root directory that every dataset path is joined onto.
    * ``"datasets"``: mapping of dataset name to a settings dict with
      ``"id"`` (key for lookups by id), ``"path"`` (file or directory,
      relative to ``cp_path``), ``"tag"`` (key read from every entry of the
      JSON data) and, optionally, ``"excludes"`` (file names to skip when
      ``"path"`` is a directory).
    """

    def __init__(self, config_path: str = DATAS_CONFIG) -> None:
        """Read the dataset config found at *config_path*.

        Raises:
            OSError: if the config file cannot be opened.
            json.JSONDecodeError: if the config file is not valid JSON.
            KeyError: if ``"cp_path"``, ``"datasets"`` or a dataset's
                ``"id"`` is missing.
        """
        self._path = config_path
        with open(config_path, 'r', encoding='utf-8') as config:
            data = json.load(config)
        self.top_level_path: str = data["cp_path"]
        self.datasets: dict = data["datasets"]
        # Reverse index: dataset id -> dataset name.
        self.id_table = {
            v["id"]: k for (k, v) in self.datasets.items()
        }

    @staticmethod
    def _read_bodies(json_path: str, tag: str) -> list:
        """Return the concatenated ``entry[tag]`` values of one JSON file."""
        with open(json_path, mode='r', encoding='utf-8') as file:
            entries = json.load(file)
        bodies = []
        for entry in entries:
            # Each tagged value is spliced in element by element, so the
            # result is one flat list across all entries.
            bodies.extend(entry[tag])
        return bodies

    def body_extractor(self, target: str) -> "list | None":
        """Collect the tagged content of every entry in dataset *target*.

        Args:
            target: dataset name, i.e. a key of ``self.datasets``.

        Returns:
            A flat list built by concatenating ``entry[tag]`` for every entry
            of every JSON file belonging to the dataset, or ``None`` (after
            printing a message) when *target* is not a configured dataset.

        Raises:
            OSError: if the dataset path or one of its files cannot be read.
            json.JSONDecodeError: if a file is not valid JSON.
            KeyError: if the dataset settings lack ``"tag"`` or ``"path"``,
                or an entry lacks the tag.
        """
        if target not in self.datasets:
            print(f"{target} is not included in datas.json as a dataset")
            return None
        configs = self.datasets[target]
        tag = configs["tag"]
        full_path = os.path.join(self.top_level_path, configs["path"])

        if os.path.isfile(full_path):  # single file json
            return self._read_bodies(full_path, tag)

        # A directory. The skip list is optional: a dataset without an
        # "excludes" entry simply skips nothing.
        excludes = configs.get("excludes", ())
        body = []  # may get a bit huge...
        # Files are visited in whatever order os.listdir reports them.
        for filename in os.listdir(full_path):
            if filename in excludes:
                continue
            body.extend(
                self._read_bodies(os.path.join(full_path, filename), tag)
            )
        return body

    def extract_from_multiple(self, targets: list) -> list:
        """Concatenate the bodies of several datasets, in the given order.

        Unknown dataset names are reported by ``body_extractor`` and then
        skipped, so one bad name does not discard the other results.
        """
        results = []
        for target in targets:
            body = self.body_extractor(target)
            if body is not None:
                results += body
        return results

    def extract_with_ids(self, ids: list) -> list:
        """Concatenate the bodies of the datasets with the given ids.

        Raises:
            KeyError: if an id is not present in ``self.id_table``.
        """
        results = []
        for dataset_id in ids:
            body = self.body_extractor(self.id_table[dataset_id])
            # Only None if self.datasets was changed after construction.
            if body is not None:
                results += body
        return results
if __name__ == "__main__":
    # Manual smoke run against the default config: show the id -> name
    # table, then exercise each of the three extraction entry points.
    demo = PlainDataLoader()
    print(demo.id_table)

    # Last item extracted from a single dataset.
    huajianji_body = demo.body_extractor("wudai-huajianji")
    print(huajianji_body[-1])

    # Total number of items across two datasets looked up by name.
    combined_body = demo.extract_from_multiple(
        ["wudai-huajianji", "wudai-nantang"]
    )
    print(len(combined_body))

    # Everything extracted from the datasets with ids 0, 1 and 2.
    body_by_ids = demo.extract_with_ids([0, 1, 2])
    print(body_by_ids)