-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
45 lines (37 loc) · 1.1 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from glob import glob
from typing import Iterable, Optional
import yaml
from pydantic import BaseModel
class Config(BaseModel):
    """Run settings, validated by pydantic and loadable from a YAML file.

    Fields without a default must be present in the YAML document;
    the others fall back to the defaults declared here.
    """

    # debug: bool = False

    # Required hyperparameters (no defaults).
    batchsize: int
    # ignore_index: int = -100
    epochs: int
    learning_rate: float

    # dropout: float = 0.1
    # transformer_activation: str = "relu"
    # # betas: List[float] = [0.9, 0.999]

    # NOTE(review): presumably caps on how many train/dev instances are
    # used; the large default would then mean "no limit" -- confirm with
    # the code that reads these fields.
    limit_train: int = 1000000000
    limit_dev: int = 1000000000

    # Dataset file lists. Entries are treated as glob patterns by
    # expand_filenames().
    training_data: list[str]
    dev_data: list[str]
    test_data: list[str]

    supertag_vocabulary_filename: str = "supertag_vocabulary.txt"
    model_filename: Optional[str] = None

    def expand_filenames(self, dataset: list[str]) -> list[str]:
        """Expand a list of glob patterns into a sorted list of paths.

        Call as e.g. ``config.expand_filenames(config.training_data)``.

        :param dataset: glob patterns, each passed to ``glob.glob``.
        :return: all matching paths from all patterns, sorted
            lexicographically. A path matched by several patterns appears
            once per match; a pattern matching nothing contributes nothing.
        """
        return sorted(name for pattern in dataset for name in glob(pattern))

    @staticmethod
    def load(filename: str) -> "Config":
        """Read the YAML file at ``filename`` and build a ``Config`` from it.

        :param filename: path of the YAML configuration file to read.
        :return: the ``Config`` built from the file's top-level mapping.
        """
        # Open the file the caller asked for; this previously always read
        # the hard-coded "config.yml" and ignored the argument.
        with open(filename, "r") as f:
            config = Config(**yaml.safe_load(f))
        return config