@@ -82,18 +82,17 @@ def evaluate(model: torch.nn.Module, dataset: Im2LatexDataset, args: Munch, num_
82
82
parser .add_argument ('-c' , '--checkpoint' , default = 'checkpoints/weights.pth' , type = str , help = 'path to model checkpoint' )
83
83
parser .add_argument ('-d' , '--data' , default = 'dataset/data/val.pkl' , type = str , help = 'Path to Dataset pkl file' )
84
84
parser .add_argument ('--no-cuda' , action = 'store_true' , help = 'Use CPU' )
85
- parser .add_argument ('-b' , '--batchsize' , type = int , default = None , help = 'Batch size' )
85
+ parser .add_argument ('-b' , '--batchsize' , type = int , default = 10 , help = 'Batch size' )
86
86
parser .add_argument ('--debug' , action = 'store_true' , help = 'DEBUG' )
87
87
88
88
parsed_args = parser .parse_args ()
89
89
with parsed_args .config as f :
90
90
params = yaml .load (f , Loader = yaml .FullLoader )
91
91
args = parse_args (Munch (params ))
92
- if parsed_args .batchsize is not None :
93
- args .testbatchsize = parsed_args .batchsize
92
+ args .testbatchsize = parsed_args .batchsize
94
93
args .wandb = False
95
94
logging .getLogger ().setLevel (logging .DEBUG if parsed_args .debug else logging .WARNING )
96
- seed_everything (args .seed )
95
+ seed_everything (args .seed if 'seed' in args else 42 )
97
96
model = get_model (args )
98
97
if parsed_args .checkpoint is not None :
99
98
model .load_state_dict (torch .load (parsed_args .checkpoint , args .device ))
0 commit comments