Skip to content

Commit

Permalink
[CLUSTER] 3x3_len10
Browse files Browse the repository at this point in the history
  • Loading branch information
manuel-delverme committed Nov 19, 2021
1 parent e5eebc5 commit e531f74
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 60 deletions.
6 changes: 2 additions & 4 deletions lstmDQN/custom_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@ def __init__(self, observation_space, action_space, device):

self.action_vocab_size = len(action_space.vocab)

self.model = LSTMDQN(
len(observation_space.vocab), len(action_space.vocab), self.device, action_space.sequence_length,
config.embedding_size, config.encoder_rnn_hidden_size, config.action_scorer_hidden_dim,
)
self.model = LSTMDQN(len(observation_space.vocab), len(action_space.vocab), action_space.sequence_length, config.embedding_size, config.encoder_rnn_hidden_size,
config.action_scorer_hidden_dim, )

# obs_vocab_size, action_vocab_size, device, output_length: int,
# embedding_size, encoder_rnn_hidden_size, action_scorer_hidden_dim, ):
Expand Down
7 changes: 2 additions & 5 deletions lstmDQN/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
class LSTMDQN(torch.nn.Module):
batch_size = 1

def __init__(
self, obs_vocab_size, action_vocab_size, device, output_length: int,
embedding_size, encoder_rnn_hidden_size, action_scorer_hidden_dim, ):
def __init__(self, obs_vocab_size, action_vocab_size, output_length: int, embedding_size, encoder_rnn_hidden_size, action_scorer_hidden_dim, ):
super(LSTMDQN, self).__init__()
self.device = device

self.word_embedding = torch.nn.Embedding(obs_vocab_size, embedding_size, device=self.device)
self.word_embedding = torch.nn.Embedding(obs_vocab_size, embedding_size)
self.encoder = torch.nn.GRU(embedding_size, encoder_rnn_hidden_size, batch_first=True)
self.Q_features = torch.nn.Sequential(
torch.nn.Linear(encoder_rnn_hidden_size, action_scorer_hidden_dim),
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ numpy
gym
stable-baselines3
icecream
git+https://github.com/ministry-of-silly-code/experiment_buddy.git@main#egg=experiment_buddy
experiment_buddy
86 changes: 40 additions & 46 deletions sweep.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,50 +8,44 @@ parameters:
distribution: log_uniform
min: -7.76804212
max: -4.76804212

^replay_batch_size:
distribution: categorical
values:
- 32
- 64
- 128
- 256
- 1024
- 2048
- 4096

^embedding_size:
distribution: categorical
values:
- 32
- 64
- 128
- 256

^encoder_rnn_hidden_size:
distribution: categorical
values:
- 64
- 128
- 256
- 1024

^action_socrer_hidden_dim:
distribution: categorical
values:
- 64
- 128
- 256

^update_per_k_game_steps:
distribution: q_uniform
min: 1
max: 10
q: 1

^epsilon_anneal_episodes:
distribution: q_uniform
min: 1000
max: 1000000
q: 1000
^replay_batch_size:
distribution: categorical
values:
- 32
- 64
- 128
- 256
- 1024
- 2048
- 4096
^embedding_size:
distribution: categorical
values:
- 32
- 64
- 128
- 256
^encoder_rnn_hidden_size:
distribution: categorical
values:
- 64
- 128
- 256
- 1024
^action_socrer_hidden_dim:
distribution: categorical
values:
- 64
- 128
- 256
^update_per_k_game_steps:
distribution: q_uniform
min: 1
max: 10
q: 1
^epsilon_anneal_episodes:
distribution: q_uniform
min: 1000
max: 1000000
q: 1000
program: train.py
9 changes: 5 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ def train(tb):

if __name__ == '__main__':
experiment_buddy.register_defaults(vars(config))
PROC_NUM = 1
PROC_NUM = 10
# HOST = "mila" if config.user == "esac" else ""
HOST = ""
YAML_FILE = "" # "sweep.yml" # "env_suite.yml"
tb = experiment_buddy.deploy(host=HOST, sweep_yaml=YAML_FILE, proc_num=PROC_NUM, wandb_kwargs={"mode": "disabled" if config.DEBUG else "online", "entity": "rl-sql"})
HOST = "mila"
RUN_SWEEP = True
tb = experiment_buddy.deploy(host=HOST, sweep_yaml="sweep.yml" if RUN_SWEEP else "", proc_num=PROC_NUM,
wandb_kwargs={"mode": "disabled" if config.DEBUG else "online", "entity": "rl-sql"})
train(tb)

0 comments on commit e531f74

Please sign in to comment.