Skip to content

Commit

Permalink
Fix small bug with status
Browse files Browse the repository at this point in the history
  • Loading branch information
lcswillems committed Aug 30, 2018
1 parent f583090 commit 795a87d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
15 changes: 10 additions & 5 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,20 @@

preprocess_obss = utils.ObssPreprocessor(save_dir, envs[0].observation_space)

# Load training status

try:
status = utils.load_status(save_dir)
except OSError:
status = {"num_frames": 0, "update": 0}

# Define actor-critic model

if utils.model_exists(save_dir):
try:
acmodel = utils.load_model(save_dir)
status = utils.load_status(save_dir)
logger.info("Model successfully loaded\n")
else:
except OSError:
acmodel = ACModel(preprocess_obss.obs_space, envs[0].action_space, not args.no_instr, not args.no_mem)
status = {"num_frames": 0, "update": 0}
logger.info("Model successfully created\n")
logger.info("{}\n".format(acmodel))

Expand Down Expand Up @@ -175,7 +180,7 @@
header += ["return_" + key for key in return_per_episode.keys()]
data += return_per_episode.values()

if not(status["num_frames"]):
if status["num_frames"] == 0:
csv_writer.writerow(header)
csv_writer.writerow(data)
csv_file.flush()
Expand Down
4 changes: 0 additions & 4 deletions utils/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@
def get_model_path(save_dir):
return os.path.join(save_dir, "model.pt")

def model_exists(save_dir):
path = get_model_path(save_dir)
return os.path.exists(path)

def load_model(save_dir):
path = get_model_path(save_dir)
model = torch.load(path)
Expand Down

0 comments on commit 795a87d

Please sign in to comment.