
Commit 137d962: [CLUSTER] use_entropy_

d3sm0 committed Jun 26, 2021
1 parent 9e74193
Showing 5 changed files with 42 additions and 4 deletions.
2 changes: 1 addition & 1 deletion config.py
@@ -38,5 +38,5 @@
 DEBUG = sys.gettrace() is not None
 PROC_NUM = 1
 HOST = "mila" if user in ("d3sm0", "esac") else ""
-YAML_FILE = "env_suite.yml"
+YAML_FILE = ""  # "env_suite.yml"
 tb = experiment_buddy.deploy(host=HOST, sweep_yaml=YAML_FILE, proc_num=PROC_NUM, wandb_kwargs={"mode": "disabled" if DEBUG else "online", "entity": "rl-sql"})
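Side note on this hunk: emptying YAML_FILE presumably makes experiment_buddy.deploy skip the sweep config and launch a single run, while the DEBUG flag above keys off the interpreter's trace hook. A minimal standalone sketch of that debugger check (the printed wandb mode is illustrative):

import sys

# Debuggers (pdb, IDE debuggers) install a trace function, so
# sys.gettrace() is non-None exactly when a debugger is attached.
DEBUG = sys.gettrace() is not None
print("wandb mode:", "disabled" if DEBUG else "online")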
Binary file added dataset/sqli.csv
Binary file added dataset/sqliv2.csv
6 changes: 3 additions & 3 deletions ppo/ppo.py
@@ -33,12 +33,12 @@ def update(self, rollouts):

 # Reshape to do in a single forward pass for all steps
 values, action_log_probs, parsed_actions, concentration = self.actor_critic.evaluate_actions(obs_batch, actions_batch)
-# entropy = - torch.einsum('btx,btx->bt', torch.exp(action_log_probs), action_log_probs).mean()
-entropy = - (torch.log(concentration + 1e-6) * concentration).sum(-1).mean()
+entropy = - torch.einsum('btx,btx->bt', torch.exp(action_log_probs), action_log_probs).mean()
+# entropy = - (torch.log(concentration + 1e-6) * concentration).sum(-1).mean()
 action_log_probs = torch.einsum("btx,btx->bt", action_log_probs, parsed_actions)
 old_action_log_probs_batch = torch.einsum("btx,btx->bt", old_action_log_probs_batch, parsed_actions)
 ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
-action_loss = - (concentration * ratio * adv_targ).mean()
+action_loss = - (ratio * adv_targ).mean()
 # surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
 # action_loss = -torch.min(surr1, surr2).mean()
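A self-contained sketch of the entropy term this hunk switches to, assuming action_log_probs holds per-step categorical log-probabilities of shape (batch, time, num_actions); all shapes and names here are illustrative stand-ins:

import torch

batch, time, num_actions = 4, 8, 6
logits = torch.randn(batch, time, num_actions)
action_log_probs = torch.log_softmax(logits, dim=-1)

# H = -sum_x p(x) * log p(x), averaged over batch and time steps,
# matching the einsum form enabled in the hunk above.
entropy = -torch.einsum("btx,btx->bt", torch.exp(action_log_probs), action_log_probs).mean()
print(entropy)  # a positive scalar for any non-degenerate distribution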

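For reference, the hunk also drops the concentration factor from the policy loss, optimizing the plain (unclipped) surrogate while PPO's clipped variant stays commented out. A sketch of both objectives over stand-in tensors (clip_param and the shapes are illustrative):

import torch

clip_param = 0.2                             # stand-in for self.clip_param
ratio = torch.exp(0.1 * torch.randn(4, 8))   # stand-in for the new/old prob ratio
adv_targ = torch.randn(4, 8)                 # stand-in for advantage estimates

surr1 = ratio * adv_targ
surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv_targ
unclipped_loss = -surr1.mean()                  # the form this commit keeps
clipped_loss = -torch.min(surr1, surr2).mean()  # the commented-out PPO-clip form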
38 changes: 38 additions & 0 deletions trash.py
@@ -0,0 +1,38 @@
+episode_distances.clear()
+
+if network_updates % config.log_query_interval == 0 and network_updates:
+    data.extend([[network_updates, rollout_step, q, float(r), str(o), i["template"]] for q, r, o, i in
+                 zip(queries, reward, obs, infos)])
+
+    # for info in infos:
+    #     if 'episode' in info.keys():
+    #         # It's done.
+    #         r = info['episode']['r']  # .detach().numpy()
+    #         episode_rewards.append(r)
+    #         solved = info["solved"]
+    #         success_rate[info['columns']].append(solved)
+    #         # agent.entropy_coef /= (1 + float(success_rate[-1]))
+
+    #         episode_distances.append(info['similarity'])
+
+    # config.tb.run.log({"train_queries": wandb.Table(columns=["network_update", "rollout_step", "query", "reward", "observation", "template"], data=data)})
+    config.tb.add_histogram("train/log_prob", action_logprob, global_step=network_updates)
+    config.tb.add_histogram('train/log_prob_per_action',
+                            np.histogram(np.arange(action_logprob.shape[0]), weights=action_logprob),
+                            global_step=network_updates)
+    config.tb.add_scalar("train/fps", int(total_num_steps / (end - start)), global_step=network_updates)
+    config.tb.add_scalar("train/avg_rw", np.mean(episode_rewards), global_step=network_updates)
+    config.tb.add_scalar("train/max_return", np.max(episode_rewards), global_step=network_updates)
+    config.tb.add_scalar("train/entropy", dist_entropy, global_step=network_updates)
+    config.tb.add_scalar("train/mean_distance", np.mean(episode_distances), global_step=network_updates)
+    config.tb.add_scalar("train/value_loss", value_loss, global_step=network_updates)
+    config.tb.add_scalar("train/action_loss", action_loss, global_step=network_updates)
+    for idx, sr in enumerate(success_rate):
+        if len(sr):
+            config.tb.add_scalar(f"train/success_rate{idx + 1}", np.mean(sr), global_step=network_updates)
+
+    if len(success_rate[-1]) == success_rate[-1].maxlen and np.mean(success_rate[-1]) >= 0.75:
+        successes += 1
+        if successes > 10:
+            print("Done :)")
+            return
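The stopping rule at the end waits for a full window before trusting the average; a sketch assuming each success_rate entry is a bounded collections.deque of 0/1 outcomes (the maxlen and success probability are hypothetical):

from collections import deque

import numpy as np

window = deque(maxlen=100)  # hypothetical window size
for outcome in np.random.binomial(1, 0.8, size=500):
    window.append(outcome)
    # Only check the mean once the window is full, as in the snippet above.
    if len(window) == window.maxlen and np.mean(window) >= 0.75:
        print("Done :)")
        break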
