-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
56 lines (48 loc) ยท 1.56 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import transformers
import yaml
from scipy.stats import pearsonr
def compute_pearson_correlation(
pred: transformers.trainer_utils.EvalPrediction,
) -> dict:
"""
ํผ์ด์จ ์๊ด ๊ณ์๋ฅผ ๊ณ์ฐํด์ฃผ๋ ํจ์
Args:
pred (torch.Tensor): ๋ชจ๋ธ์ ์์ธก๊ฐ๊ณผ ๋ ์ด๋ธ์ ํฌํจํ ๋ฐ์ดํฐ
Returns:
perason_correlation (dict): ์
๋ ฅ๊ฐ์ ํตํด ๊ณ์ฐํ ํผ์ด์จ ์๊ด ๊ณ์
"""
preds = pred.predictions.flatten()
labels = pred.label_ids.flatten()
perason_correlation = {"pearson_correlation": pearsonr(preds, labels)[0]}
return perason_correlation
def seed_everything(seed: int) -> None:
"""
๋ชจ๋ธ์์ ์ฌ์ฉํ๋ ๋ชจ๋ ๋๋ค ์๋๋ฅผ ๊ณ ์ ํด์ฃผ๋ ํจ์
Args:
seed (int): ์๋ ๊ณ ์ ์ ์ฌ์ฉํ ์ ์๊ฐ
Returns:
None
"""
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
cudnn.deterministic = True
cudnn.benchmark = True
def load_yaml(path: str) -> dict:
"""
๋ชจ๋ธ ํ๋ จ, ์์ธก์ ์ฌ์ฉํ yaml ํ์ผ์ ๋ถ๋ฌ์ค๋ ํจ์
Args:
path (str): ๋ถ๋ฌ์ฌ yaml ํ์ผ์ ๊ฒฝ๋ก
Returns:
loaded_yaml (dict): ์ง์ ํ ๊ฒฝ๋ก์์ ๋ถ๋ฌ์จ yaml ํ์ผ ๋ด์ฉ
"""
with open(path, "r") as f:
loaded_yaml = yaml.load(f, Loader=yaml.FullLoader)
return loaded_yaml