Skip to content

Commit 5d5c4ea

Browse files
author
Ervin T
authored
Pytorch port of SAC (#4219)
1 parent da3a7f8 commit 5d5c4ea

File tree

8 files changed

+871
-85
lines changed

8 files changed

+871
-85
lines changed

experiment_torch.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def run_experiment(
1414
name: str,
1515
steps: int,
1616
use_torch: bool,
17+
algo: str,
1718
num_torch_threads: int,
1819
use_gpu: bool,
1920
num_envs: int = 1,
@@ -32,6 +33,7 @@ def run_experiment(
3233
name,
3334
str(steps),
3435
str(use_torch),
36+
algo,
3537
str(num_torch_threads),
3638
str(num_envs),
3739
str(use_gpu),
@@ -46,7 +48,7 @@ def run_experiment(
4648
if config_name is None:
4749
config_name = name
4850
run_options = parse_command_line(
49-
[f"config/ppo/{config_name}.yaml", "--num-envs", f"{num_envs}"]
51+
[f"config/{algo}/{config_name}.yaml", "--num-envs", f"{num_envs}"]
5052
)
5153
run_options.checkpoint_settings.run_id = (
5254
f"{name}_test_" + str(steps) + "_" + ("torch" if use_torch else "tf")
@@ -87,20 +89,29 @@ def run_experiment(
8789
tc_advance_total = tc_advance["total"]
8890
tc_advance_count = tc_advance["count"]
8991
if use_torch:
90-
update_total = update["TorchPPOOptimizer.update"]["total"]
92+
if algo == "ppo":
93+
update_total = update["TorchPPOOptimizer.update"]["total"]
94+
update_count = update["TorchPPOOptimizer.update"]["count"]
95+
else:
96+
update_total = update["SACTrainer._update_policy"]["total"]
97+
update_count = update["SACTrainer._update_policy"]["count"]
9198
evaluate_total = evaluate["TorchPolicy.evaluate"]["total"]
92-
update_count = update["TorchPPOOptimizer.update"]["count"]
9399
evaluate_count = evaluate["TorchPolicy.evaluate"]["count"]
94100
else:
95-
update_total = update["TFPPOOptimizer.update"]["total"]
101+
if algo == "ppo":
102+
update_total = update["TFPPOOptimizer.update"]["total"]
103+
update_count = update["TFPPOOptimizer.update"]["count"]
104+
else:
105+
update_total = update["SACTrainer._update_policy"]["total"]
106+
update_count = update["SACTrainer._update_policy"]["count"]
96107
evaluate_total = evaluate["NNPolicy.evaluate"]["total"]
97-
update_count = update["TFPPOOptimizer.update"]["count"]
98108
evaluate_count = evaluate["NNPolicy.evaluate"]["count"]
99109
# todo: do total / count
100110
return (
101111
name,
102112
str(steps),
103113
str(use_torch),
114+
algo,
104115
str(num_torch_threads),
105116
str(num_envs),
106117
str(use_gpu),
@@ -133,28 +144,41 @@ def main():
133144
action="store_true",
134145
help="If true, will only do 3dball",
135146
)
147+
parser.add_argument(
148+
"--sac",
149+
default=False,
150+
action="store_true",
151+
help="If true, will run sac instead of ppo",
152+
)
136153
args = parser.parse_args()
137154

138155
if args.gpu:
139156
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
140157
else:
141158
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
142159

160+
algo = "ppo"
161+
if args.sac:
162+
algo = "sac"
163+
143164
envs_config_tuples = [
144165
("3DBall", "3DBall"),
145166
("GridWorld", "GridWorld"),
146167
("PushBlock", "PushBlock"),
147-
("Hallway", "Hallway"),
148168
("CrawlerStaticTarget", "CrawlerStatic"),
149-
("VisualHallway", "VisualHallway"),
150169
]
170+
if algo == "ppo":
171+
envs_config_tuples += [("Hallway", "Hallway"),
172+
("VisualHallway", "VisualHallway")]
151173
if args.ball:
152174
envs_config_tuples = [("3DBall", "3DBall")]
153175

176+
154177
labels = (
155178
"name",
156179
"steps",
157180
"use_torch",
181+
"algorithm",
158182
"num_torch_threads",
159183
"num_envs",
160184
"use_gpu",
@@ -170,7 +194,7 @@ def main():
170194
results = []
171195
results.append(labels)
172196
f = open(
173-
f"result_data_steps_{args.steps}_envs_{args.num_envs}_gpu_{args.gpu}_thread_{args.threads}.txt",
197+
f"result_data_steps_{args.steps}_algo_{algo}_envs_{args.num_envs}_gpu_{args.gpu}_thread_{args.threads}.txt",
174198
"w",
175199
)
176200
f.write(" ".join(labels) + "\n")
@@ -180,6 +204,7 @@ def main():
180204
name=env_config[0],
181205
steps=args.steps,
182206
use_torch=True,
207+
algo=algo,
183208
num_torch_threads=1,
184209
use_gpu=args.gpu,
185210
num_envs=args.num_envs,
@@ -193,6 +218,7 @@ def main():
193218
name=env_config[0],
194219
steps=args.steps,
195220
use_torch=True,
221+
algo=algo,
196222
num_torch_threads=8,
197223
use_gpu=args.gpu,
198224
num_envs=args.num_envs,
@@ -205,6 +231,7 @@ def main():
205231
name=env_config[0],
206232
steps=args.steps,
207233
use_torch=False,
234+
algo=algo,
208235
num_torch_threads=1,
209236
use_gpu=args.gpu,
210237
num_envs=args.num_envs,

ml-agents/mlagents/trainers/distributions_torch.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@ def __init__(self, mean, std):
1313
self.std = std
1414

1515
def sample(self):
16-
return self.mean + torch.randn_like(self.mean) * self.std
16+
sample = self.mean + torch.randn_like(self.mean) * self.std
17+
return sample
1718

1819
def log_prob(self, value):
1920
var = self.std ** 2
20-
log_scale = self.std.log()
21+
log_scale = torch.log(self.std + EPSILON)
2122
return (
22-
-((value - self.mean) ** 2) / (2 * var)
23+
-((value - self.mean) ** 2) / (2 * var + EPSILON)
2324
- log_scale
2425
- math.log(math.sqrt(2 * math.pi))
2526
)
@@ -29,7 +30,28 @@ def pdf(self, value):
2930
return torch.exp(log_prob)
3031

3132
def entropy(self):
32-
return torch.log(2 * math.pi * math.e * self.std)
33+
return torch.log(2 * math.pi * math.e * self.std + EPSILON)
34+
35+
36+
class TanhGaussianDistInstance(GaussianDistInstance):
    """Gaussian distribution instance whose samples are squashed by tanh.

    Sampling draws from the parent Gaussian and applies a ``TanhTransform``;
    ``log_prob`` inverts the transform and subtracts the log absolute
    determinant of the Jacobian, as reported by the transform.
    """

    def __init__(self, mean, std):
        """Store ``mean``/``std`` via the parent class and build the transform."""
        super().__init__(mean, std)
        # cache_size=1 asks the transform to remember its latest
        # (input, output) pair (see the torch.distributions Transform docs).
        self.transform = torch.distributions.transforms.TanhTransform(cache_size=1)

    def sample(self):
        """Return a parent-class sample passed through the tanh transform."""
        return self.transform(super().sample())

    def _inverse_tanh(self, value):
        """Return a clamped inverse tanh of ``value``.

        ``value`` is clamped to ``[-1 + EPSILON, 1 - EPSILON]`` before the
        log-ratio is taken, and EPSILON is added inside the logarithm.
        NOTE(review): nothing in this class calls this helper.
        """
        bounded = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
        ratio = (1 + bounded) / (1 - bounded)
        return 0.5 * torch.log(ratio + EPSILON)

    def log_prob(self, value):
        """Return the log-probability of the squashed ``value``.

        The value is mapped back through the inverse transform, scored by the
        parent Gaussian, and corrected by the transform's log|det J| term.
        """
        pre_tanh = self.transform.inv(value)
        gaussian_log_prob = super().log_prob(pre_tanh)
        jacobian_term = self.transform.log_abs_det_jacobian(pre_tanh, value)
        return gaussian_log_prob - jacobian_term
3355

3456

3557
class CategoricalDistInstance(nn.Module):
@@ -47,15 +69,26 @@ def pdf(self, value):
4769
def log_prob(self, value):
    """Return the natural log of ``self.pdf(value)``."""
    density = self.pdf(value)
    return torch.log(density)
4971

72+
def all_log_prob(self):
    """Return the element-wise natural log of ``self.probs``."""
    probabilities = self.probs
    return torch.log(probabilities)
74+
5075
def entropy(self):
    """Return the Shannon entropy of ``self.probs`` over the last dimension.

    Computes ``-sum(p * log(p), dim=-1)``. The leading minus sign is
    required: ``sum(p * log(p))`` on its own is the *negative* entropy, so
    using it as an entropy bonus would reward less exploration, not more.

    NOTE(review): an entry of exactly 0 in ``self.probs`` yields NaN
    (0 * -inf); this is unchanged from the previous behavior.
    """
    return -torch.sum(self.probs * torch.log(self.probs), dim=-1)
5277

5378

5479
class GaussianDistribution(nn.Module):
55-
def __init__(self, hidden_size, num_outputs, conditional_sigma=False, **kwargs):
80+
def __init__(
81+
self,
82+
hidden_size,
83+
num_outputs,
84+
conditional_sigma=False,
85+
tanh_squash=False,
86+
**kwargs
87+
):
5688
super(GaussianDistribution, self).__init__(**kwargs)
5789
self.conditional_sigma = conditional_sigma
5890
self.mu = nn.Linear(hidden_size, num_outputs)
91+
self.tanh_squash = tanh_squash
5992
nn.init.xavier_uniform_(self.mu.weight, gain=0.01)
6093
if conditional_sigma:
6194
self.log_sigma = nn.Linear(hidden_size, num_outputs)
@@ -68,10 +101,13 @@ def __init__(self, hidden_size, num_outputs, conditional_sigma=False, **kwargs):
68101
def forward(self, inputs):
69102
mu = self.mu(inputs)
70103
if self.conditional_sigma:
71-
log_sigma = self.log_sigma(inputs)
104+
log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
72105
else:
73106
log_sigma = self.log_sigma
74-
return [GaussianDistInstance(mu, torch.exp(log_sigma))]
107+
if self.tanh_squash:
108+
return [TanhGaussianDistInstance(mu, torch.exp(log_sigma))]
109+
else:
110+
return [GaussianDistInstance(mu, torch.exp(log_sigma))]
75111

76112

77113
class MultiCategoricalDistribution(nn.Module):

0 commit comments

Comments
 (0)