From e69b31713b958bbde94b8dbfb3bb273f04ddb6e4 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Mon, 16 Oct 2023 21:57:57 -0400
Subject: [PATCH] quick change

---
 cleanrl/qdagger_dqn_atari_impalacnn.py     | 6 ++++--
 cleanrl/qdagger_dqn_atari_jax_impalacnn.py | 6 ++++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/cleanrl/qdagger_dqn_atari_impalacnn.py b/cleanrl/qdagger_dqn_atari_impalacnn.py
index 1c694d43b..1ea312923 100644
--- a/cleanrl/qdagger_dqn_atari_impalacnn.py
+++ b/cleanrl/qdagger_dqn_atari_impalacnn.py
@@ -86,6 +86,8 @@ class Args:
     # QDagger specific arguments
     teacher_policy_hf_repo: str = None
     """the huggingface repo of the teacher policy"""
+    teacher_model_exp_name: str = "dqn_atari"
+    """the experiment name of the teacher model"""
     teacher_eval_episodes: int = 10
     """the number of episodes to run the teacher policy evaluate"""
     teacher_steps: int = 500000
@@ -206,7 +208,7 @@ def kl_divergence_with_logits(target_logits, prediction_logits):
     args = tyro.cli(Args)
     assert args.num_envs == 1, "vectorized envs are not supported at the moment"
     if args.teacher_policy_hf_repo is None:
-        args.teacher_policy_hf_repo = f"cleanrl/{args.env_id}-dqn_atari-seed1"
+        args.teacher_policy_hf_repo = f"cleanrl/{args.env_id}-{args.teacher_model_exp_name}-seed1"
     run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
     if args.track:
         import wandb
@@ -246,7 +248,7 @@ def kl_divergence_with_logits(target_logits, prediction_logits):
     target_network.load_state_dict(q_network.state_dict())

     # QDAGGER LOGIC:
-    teacher_model_path = hf_hub_download(repo_id=args.teacher_policy_hf_repo, filename="dqn_atari.cleanrl_model")
+    teacher_model_path = hf_hub_download(repo_id=args.teacher_policy_hf_repo, filename=f"{args.teacher_model_exp_name}.cleanrl_model")
     teacher_model = TeacherModel(envs).to(device)
     teacher_model.load_state_dict(torch.load(teacher_model_path, map_location=device))
     teacher_model.eval()
diff --git a/cleanrl/qdagger_dqn_atari_jax_impalacnn.py b/cleanrl/qdagger_dqn_atari_jax_impalacnn.py
index f2ebd1db0..f160bed20 100644
--- a/cleanrl/qdagger_dqn_atari_jax_impalacnn.py
+++ b/cleanrl/qdagger_dqn_atari_jax_impalacnn.py
@@ -89,6 +89,8 @@ class Args:
     # QDagger specific arguments
     teacher_policy_hf_repo: str = None
     """the huggingface repo of the teacher policy"""
+    teacher_model_exp_name: str = "dqn_atari_jax"
+    """the experiment name of the teacher model"""
     teacher_eval_episodes: int = 10
     """the number of episodes to run the teacher policy evaluate"""
     teacher_steps: int = 500000
@@ -199,7 +201,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
     args = tyro.cli(Args)
     assert args.num_envs == 1, "vectorized envs are not supported at the moment"
     if args.teacher_policy_hf_repo is None:
-        args.teacher_policy_hf_repo = f"cleanrl/{args.env_id}-dqn_atari-seed1"
+        args.teacher_policy_hf_repo = f"cleanrl/{args.env_id}-{args.teacher_model_exp_name}-seed1"
     run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
     if args.track:
         import wandb
@@ -242,7 +244,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
     q_network.apply = jax.jit(q_network.apply)

     # QDAGGER LOGIC:
-    teacher_model_path = hf_hub_download(repo_id=args.teacher_policy_hf_repo, filename="dqn_atari_jax.cleanrl_model")
+    teacher_model_path = hf_hub_download(repo_id=args.teacher_policy_hf_repo, filename=f"{args.teacher_model_exp_name}.cleanrl_model")
     teacher_model = TeacherModel(action_dim=envs.single_action_space.n)
     teacher_model_key = jax.random.PRNGKey(args.seed)
     teacher_params = teacher_model.init(teacher_model_key, envs.observation_space.sample())
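
The hunks above make the teacher checkpoint location configurable: the new
teacher_model_exp_name field now drives both the default HuggingFace repo name
and the downloaded checkpoint filename, which previously were hard-coded to the
"dqn_atari"/"dqn_atari_jax" experiment names. A minimal sketch of the resulting
lookup, assuming the defaults from the torch script and a hypothetical env_id:

    # Sketch of the teacher-checkpoint resolution after this patch.
    # env_id is a hypothetical example; teacher_model_exp_name uses the
    # default from cleanrl/qdagger_dqn_atari_impalacnn.py above.
    from huggingface_hub import hf_hub_download

    env_id = "BreakoutNoFrameskip-v4"
    teacher_model_exp_name = "dqn_atari"

    # Default repo when no --teacher-policy-hf-repo is given:
    teacher_policy_hf_repo = f"cleanrl/{env_id}-{teacher_model_exp_name}-seed1"

    # The checkpoint filename now follows the same experiment name:
    teacher_model_path = hf_hub_download(
        repo_id=teacher_policy_hf_repo,
        filename=f"{teacher_model_exp_name}.cleanrl_model",
    )

Because both strings are derived from the one field, overriding
teacher_model_exp_name keeps the repo name and the downloaded filename in sync.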