Skip to content

Commit

Permalink
add cpu support via map_location in test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirili4ik authored Oct 29, 2021
1 parent 01e4689 commit f973e72
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

def main(config, out_file):
logger = config.get_logger("test")


# define cpu or gpu if possible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# text_encoder
text_encoder = CTCCharTextEncoder.get_simple_alphabet()

Expand All @@ -30,14 +33,13 @@ def main(config, out_file):
logger.info(model)

logger.info("Loading checkpoint: {} ...".format(config.resume))
checkpoint = torch.load(config.resume)
checkpoint = torch.load(config.resume, map_location=device)
state_dict = checkpoint["state_dict"]
if config["n_gpu"] > 1:
model = torch.nn.DataParallel(model)
model.load_state_dict(state_dict)

# prepare model for testing
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

Expand Down

0 comments on commit f973e72

Please sign in to comment.