Skip to content

Commit 6c28813

Browse files
committed
specify map_location (lukas-blecher#4)
1 parent 21bdc10 commit 6c28813

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pix2tex.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def initialize(arguments):
3030
args.device = 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu'
3131

3232
model = get_model(args)
33-
model.load_state_dict(torch.load(args.checkpoint))
33+
model.load_state_dict(torch.load(args.checkpoint, map_location=args.device))
3434
tokenizer = PreTrainedTokenizerFast(tokenizer_file=args.tokenizer)
3535
return args, model, tokenizer
3636

0 commit comments

Comments
 (0)