Skip to content

Commit

Permalink
quick change
Browse files — browse the repository at this point in the history
  • Loading branch information
vwxyzjn committed Oct 17, 2023
1 parent cbbdc8b commit e69b317
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
6 changes: 4 additions & 2 deletions cleanrl/qdagger_dqn_atari_impalacnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class Args:
# QDagger specific arguments
teacher_policy_hf_repo: str = None
"""the huggingface repo of the teacher policy"""
teacher_model_exp_name: str = "dqn_atari"
"""the experiment name of the teacher model"""
teacher_eval_episodes: int = 10
"""the number of episodes to run the teacher policy evaluate"""
teacher_steps: int = 500000
Expand Down Expand Up @@ -206,7 +208,7 @@ def kl_divergence_with_logits(target_logits, prediction_logits):
args = tyro.cli(Args)
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
if args.teacher_policy_hf_repo is None:
args.teacher_policy_hf_repo = f"cleanrl/{args.env_id}-dqn_atari-seed1"
args.teacher_policy_hf_repo = f"cleanrl/{args.env_id}-{args.teacher_model_exp_name}-seed1"
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
import wandb
Expand Down Expand Up @@ -246,7 +248,7 @@ def kl_divergence_with_logits(target_logits, prediction_logits):
target_network.load_state_dict(q_network.state_dict())

# QDAGGER LOGIC:
teacher_model_path = hf_hub_download(repo_id=args.teacher_policy_hf_repo, filename="dqn_atari.cleanrl_model")
teacher_model_path = hf_hub_download(repo_id=args.teacher_policy_hf_repo, filename=f"{args.teacher_model_exp_name}.cleanrl_model")
teacher_model = TeacherModel(envs).to(device)
teacher_model.load_state_dict(torch.load(teacher_model_path, map_location=device))
teacher_model.eval()
Expand Down
6 changes: 4 additions & 2 deletions cleanrl/qdagger_dqn_atari_jax_impalacnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class Args:
# QDagger specific arguments
teacher_policy_hf_repo: str = None
"""the huggingface repo of the teacher policy"""
teacher_model_exp_name: str = "dqn_atari_jax"
"""the experiment name of the teacher model"""
teacher_eval_episodes: int = 10
"""the number of episodes to run the teacher policy evaluate"""
teacher_steps: int = 500000
Expand Down Expand Up @@ -199,7 +201,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
args = tyro.cli(Args)
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
if args.teacher_policy_hf_repo is None:
args.teacher_policy_hf_repo = f"cleanrl/{args.env_id}-dqn_atari-seed1"
args.teacher_policy_hf_repo = f"cleanrl/{args.env_id}-{args.teacher_model_exp_name}-seed1"
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
import wandb
Expand Down Expand Up @@ -242,7 +244,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
q_network.apply = jax.jit(q_network.apply)

# QDAGGER LOGIC:
teacher_model_path = hf_hub_download(repo_id=args.teacher_policy_hf_repo, filename="dqn_atari_jax.cleanrl_model")
teacher_model_path = hf_hub_download(repo_id=args.teacher_policy_hf_repo, filename=f"{args.teacher_model_exp_name}.cleanrl_model")
teacher_model = TeacherModel(action_dim=envs.single_action_space.n)
teacher_model_key = jax.random.PRNGKey(args.seed)
teacher_params = teacher_model.init(teacher_model_key, envs.observation_space.sample())
Expand Down

0 comments on commit e69b317

Please sign in to comment.