From e531f742a3ad9d7953e196d547a52ab72ec3c3f0 Mon Sep 17 00:00:00 2001 From: manuel Date: Fri, 19 Nov 2021 20:21:15 +0100 Subject: [PATCH] [CLUSTER] 3x3_len10 --- lstmDQN/custom_agent.py | 6 +-- lstmDQN/model.py | 7 +--- requirements.txt | 2 +- sweep.yml | 86 +++++++++++++++++++---------------------- train.py | 9 +++-- 5 files changed, 50 insertions(+), 60 deletions(-) diff --git a/lstmDQN/custom_agent.py b/lstmDQN/custom_agent.py index 759b08b..183e624 100644 --- a/lstmDQN/custom_agent.py +++ b/lstmDQN/custom_agent.py @@ -14,10 +14,8 @@ def __init__(self, observation_space, action_space, device): self.action_vocab_size = len(action_space.vocab) - self.model = LSTMDQN( - len(observation_space.vocab), len(action_space.vocab), self.device, action_space.sequence_length, - config.embedding_size, config.encoder_rnn_hidden_size, config.action_scorer_hidden_dim, - ) + self.model = LSTMDQN(len(observation_space.vocab), len(action_space.vocab), action_space.sequence_length, config.embedding_size, config.encoder_rnn_hidden_size, + config.action_scorer_hidden_dim, ) # obs_vocab_size, action_vocab_size, device, output_length: int, # embedding_size, encoder_rnn_hidden_size, action_scorer_hidden_dim, ): diff --git a/lstmDQN/model.py b/lstmDQN/model.py index 0f02c90..71c93f3 100644 --- a/lstmDQN/model.py +++ b/lstmDQN/model.py @@ -4,13 +4,10 @@ class LSTMDQN(torch.nn.Module): batch_size = 1 - def __init__( - self, obs_vocab_size, action_vocab_size, device, output_length: int, - embedding_size, encoder_rnn_hidden_size, action_scorer_hidden_dim, ): + def __init__(self, obs_vocab_size, action_vocab_size, output_length: int, embedding_size, encoder_rnn_hidden_size, action_scorer_hidden_dim, ): super(LSTMDQN, self).__init__() - self.device = device - self.word_embedding = torch.nn.Embedding(obs_vocab_size, embedding_size, device=self.device) + self.word_embedding = torch.nn.Embedding(obs_vocab_size, embedding_size) self.encoder = torch.nn.GRU(embedding_size, encoder_rnn_hidden_size, batch_first=True) self.Q_features = torch.nn.Sequential( torch.nn.Linear(encoder_rnn_hidden_size, action_scorer_hidden_dim), diff --git a/requirements.txt b/requirements.txt index d7f1aff..46de3d0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,4 @@ numpy gym stable-baselines3 icecream -git+https://github.com/ministry-of-silly-code/experiment_buddy.git@main#egg=experiment_buddy +experiment_buddy diff --git a/sweep.yml b/sweep.yml index d068f94..a12cf14 100644 --- a/sweep.yml +++ b/sweep.yml @@ -8,50 +8,44 @@ parameters: distribution: log_uniform min: -7.76804212 max: -4.76804212 - -^replay_batch_size: - distribution: categorical - values: - - 32 - - 64 - - 128 - - 256 - - 1024 - - 2048 - - 4096 - -^embedding_size: - distribution: categorical - values: - - 32 - - 64 - - 128 - - 256 - -^encoder_rnn_hidden_size: - distribution: categorical - values: - - 64 - - 128 - - 256 - - 1024 - -^action_socrer_hidden_dim: - distribution: categorical - values: - - 64 - - 128 - - 256 - -^update_per_k_game_steps: - distribution: q_uniform - min: 1 - max: 10 - q: 1 - -^epsilon_anneal_episodes: - distribution: q_uniform - min: 1000 - max: 1000000 - q: 1000 + ^replay_batch_size: + distribution: categorical + values: + - 32 + - 64 + - 128 + - 256 + - 1024 + - 2048 + - 4096 + ^embedding_size: + distribution: categorical + values: + - 32 + - 64 + - 128 + - 256 + ^encoder_rnn_hidden_size: + distribution: categorical + values: + - 64 + - 128 + - 256 + - 1024 + ^action_socrer_hidden_dim: + distribution: categorical + values: + - 64 + - 128 + - 256 + ^update_per_k_game_steps: + distribution: q_uniform + min: 1 + max: 10 + q: 1 + ^epsilon_anneal_episodes: + distribution: q_uniform + min: 1000 + max: 1000000 + q: 1000 program: train.py diff --git a/train.py b/train.py index 1388553..1995936 100644 --- a/train.py +++ b/train.py @@ -72,9 +72,10 @@ def train(tb): if __name__ == '__main__': experiment_buddy.register_defaults(vars(config)) - PROC_NUM = 1 + PROC_NUM = 10 # HOST = "mila" if config.user == "esac" else "" - HOST = "" - YAML_FILE = "" # "sweep.yml" # "env_suite.yml" - tb = experiment_buddy.deploy(host=HOST, sweep_yaml=YAML_FILE, proc_num=PROC_NUM, wandb_kwargs={"mode": "disabled" if config.DEBUG else "online", "entity": "rl-sql"}) + HOST = "mila" + RUN_SWEEP = True + tb = experiment_buddy.deploy(host=HOST, sweep_yaml="sweep.yml" if RUN_SWEEP else "", proc_num=PROC_NUM, + wandb_kwargs={"mode": "disabled" if config.DEBUG else "online", "entity": "rl-sql"}) train(tb)