Skip to content

Commit 0dfc102

Browse files
committed
make eval work with minimal config
1 parent 9dbd04b commit 0dfc102

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

eval.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -82,18 +82,17 @@ def evaluate(model: torch.nn.Module, dataset: Im2LatexDataset, args: Munch, num_
8282
parser.add_argument('-c', '--checkpoint', default='checkpoints/weights.pth', type=str, help='path to model checkpoint')
8383
parser.add_argument('-d', '--data', default='dataset/data/val.pkl', type=str, help='Path to Dataset pkl file')
8484
parser.add_argument('--no-cuda', action='store_true', help='Use CPU')
85-
parser.add_argument('-b', '--batchsize', type=int, default=None, help='Batch size')
85+
parser.add_argument('-b', '--batchsize', type=int, default=10, help='Batch size')
8686
parser.add_argument('--debug', action='store_true', help='DEBUG')
8787

8888
parsed_args = parser.parse_args()
8989
with parsed_args.config as f:
9090
params = yaml.load(f, Loader=yaml.FullLoader)
9191
args = parse_args(Munch(params))
92-
if parsed_args.batchsize is not None:
93-
args.testbatchsize = parsed_args.batchsize
92+
args.testbatchsize = parsed_args.batchsize
9493
args.wandb = False
9594
logging.getLogger().setLevel(logging.DEBUG if parsed_args.debug else logging.WARNING)
96-
seed_everything(args.seed)
95+
seed_everything(args.seed if 'seed' in args else 42)
9796
model = get_model(args)
9897
if parsed_args.checkpoint is not None:
9998
model.load_state_dict(torch.load(parsed_args.checkpoint, args.device))

settings/config.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
batchsize: 10
22
bos_token: 1
33
channels: 1
4+
debug: false
45
device: cuda
56
dim: 256
67
encoder_depth: 4

utils/utils.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ def parse_args(args, **kwargs):
4949
args.wandb = not kwargs.debug and not args.debug
5050
args.device = 'cuda' if torch.cuda.is_available() and not kwargs.no_cuda else 'cpu'
5151
args.max_dimensions = [args.max_width, args.max_height]
52-
args.out_path = os.path.join(args.model_path, args.name)
53-
os.makedirs(args.out_path, exist_ok=True)
52+
if 'model_path' in args:
53+
args.out_path = os.path.join(args.model_path, args.name)
54+
os.makedirs(args.out_path, exist_ok=True)
5455
return args
5556

5657

0 commit comments

Comments
 (0)