eval.py
'''
@copyright ziqi-jin
'''
import argparse
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from datasets import get_dataset
from losses import get_losses
from extend_sam import get_model, get_optimizer, get_scheduler, get_opt_pamams, get_runner
supported_tasks = ['detection', 'semantic_seg', 'instance_seg']
parser = argparse.ArgumentParser()
parser.add_argument('--task_name', default='semantic_seg', type=str)
parser.add_argument('--cfg', default=None, type=str)
parser.add_argument('--ckp', required=True, type=str)
parser.add_argument('--dump-path', default=None, type=str)
parser.add_argument('--img', default=None, type=str)
parser.add_argument('--prompt', default=None, type=str)
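# Example invocations (illustrative paths; the expected --prompt format depends on
# the runner's run_one_image implementation):
#   python eval.py --task_name semantic_seg --ckp path/to/checkpoint.pth
#   python eval.py --cfg ./config/semantic_seg.yaml --ckp path/to/checkpoint.pth \
#       --img path/to/image.jpg --prompt <prompt>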
if __name__ == '__main__':
    args = parser.parse_args()
    task_name = args.task_name
    if args.cfg is not None:
        config = OmegaConf.load(args.cfg)
    else:
        assert task_name in supported_tasks, "Please input a supported task name."
        config = OmegaConf.load("./config/{task_name}.yaml".format(task_name=args.task_name))
    train_cfg = config.train
    val_cfg = config.val
    test_cfg = config.test  # parsed but unused in this script
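    # Evaluation reuses the train section of the config for the model, losses, and
    # optimizer; only the validation dataloader (or a single image) is actually run,
    # so the training loader below stays None.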
    # train_dataset = get_dataset(train_cfg.dataset)
    # train_loader = DataLoader(train_dataset, batch_size=train_cfg.bs, shuffle=True,
    #                           num_workers=train_cfg.num_workers, drop_last=train_cfg.drop_last)
    train_loader = None
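    # Single-image mode: when both --img and --prompt are given, skip building the
    # validation dataset and run inference on that one image instead.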
    if args.img is not None and args.prompt is not None:
        val_loader = None
    else:
        val_dataset = get_dataset(val_cfg.dataset)
        val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=0,
                                drop_last=val_cfg.drop_last)
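    # Losses and the model are built from the same train config used for fine-tuning,
    # presumably so the checkpoint's weights line up with the model definition.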
    losses = get_losses(losses=train_cfg.losses)
    # Build the adapted model according to the model name.
    model = get_model(model_name=train_cfg.model.sam_name, **train_cfg.model.params)
    opt_params = get_opt_pamams(model, lr_list=train_cfg.opt_params.lr_list,
                                group_keys=train_cfg.opt_params.group_keys,
                                wd_list=train_cfg.opt_params.wd_list)
    optimizer = get_optimizer(opt_name=train_cfg.opt_name, params=opt_params, lr=train_cfg.opt_params.lr_default,
                              momentum=train_cfg.opt_params.momentum, weight_decay=train_cfg.opt_params.wd_default)
    scheduler = get_scheduler(optimizer=optimizer, lr_scheduler=train_cfg.scheduler_name)
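    # get_runner returns a runner class for the task; instantiate it with the model,
    # optimizer, losses, loaders, and scheduler, then restore the fine-tuned weights.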
    runner = get_runner(train_cfg.runner_name)(model, optimizer, losses, train_loader, val_loader, scheduler)
    runner.load_checkpoint(args.ckp)
    if args.img is not None and args.prompt is not None:
        runner.run_one_image(args.img, args.prompt)
    else:
        res = runner._eval(train_cfg, dump_dir=args.dump_path, return_recall=True)
        print(res)