Skip to content

make sure top-level timer is closed before writing #3631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion ml-agents/mlagents/trainers/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig
from mlagents_envs.exception import UnityEnvironmentException
from mlagents_envs.timers import hierarchical_timer
from mlagents_envs.timers import hierarchical_timer, get_timer_tree
from mlagents.logging_util import create_logger


Expand Down Expand Up @@ -329,6 +329,18 @@ def run_training(run_seed: int, options: RunOptions) -> None:
tc.start_learning(env_manager)
finally:
env_manager.close()
write_timing_tree(summaries_dir, options.run_id)


def write_timing_tree(summaries_dir: str, run_id: str) -> None:
timing_path = f"{summaries_dir}/{run_id}_timers.json"
try:
with open(timing_path, "w") as f:
json.dump(get_timer_tree(), f, indent=4)
except FileNotFoundError:
logging.warning(
f"Unable to save to {timing_path}. Make sure the directory exists"
)


def create_sampler_manager(sampler_config, run_seed=None):
Expand Down
14 changes: 1 addition & 13 deletions ml-agents/mlagents/trainers/trainer_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import os
import sys
import json
import logging
from typing import Dict, Optional, Set
from collections import defaultdict
Expand All @@ -18,7 +17,7 @@
UnityCommunicationException,
)
from mlagents.trainers.sampler_class import SamplerManager
from mlagents_envs.timers import hierarchical_timer, get_timer_tree, timed
from mlagents_envs.timers import hierarchical_timer, timed
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.meta_curriculum import MetaCurriculum
from mlagents.trainers.trainer_util import TrainerFactory
Expand Down Expand Up @@ -106,16 +105,6 @@ def _save_model_when_interrupted(self):
)
self._save_model()

def _write_timing_tree(self) -> None:
timing_path = f"{self.summaries_dir}/{self.run_id}_timers.json"
try:
with open(timing_path, "w") as f:
json.dump(get_timer_tree(), f, indent=4)
except FileNotFoundError:
self.logger.warning(
f"Unable to save to {timing_path}. Make sure the directory exists"
)

def _export_graph(self):
"""
Exports latest saved models to .nn format for Unity embedding.
Expand Down Expand Up @@ -231,7 +220,6 @@ def start_learning(self, env_manager: EnvManager) -> None:
pass
if self.train_model:
self._export_graph()
self._write_timing_tree()

def end_trainer_episodes(
self, env: EnvManager, lessons_incremented: Dict[str, bool]
Expand Down