Skip to content

Commit a607f07

Browse files
authored
[IBR-2069] Modify hidden activation function (#317)
1 parent 3f0f095 commit a607f07

File tree

3 files changed

+11
-1
lines changed

3 files changed

+11
-1
lines changed

configs/lunarlander_continuous_v2/ppo.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ learner_cfg:
2727
type: "GaussianDist"
2828
configs:
2929
hidden_sizes: [256, 256]
30+
hidden_activation: "tanh"
3031
output_activation: "identity"
3132
fixed_logstd: True
3233
critic:

rl_algorithms/common/helper_functions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ def identity(x: torch.Tensor) -> torch.Tensor:
2828
return x
2929

3030

31+
def relu(x: torch.Tensor) -> torch.Tensor:
32+
"""Return torch.relu(x)"""
33+
return torch.relu(x)
34+
35+
3136
def soft_update(local: nn.Module, target: nn.Module, tau: float):
3237
"""Soft-update: target = tau*local + (1-tau)*target."""
3338
for t_param, l_param in zip(target.parameters(), local.parameters()):

rl_algorithms/common/networks/heads.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,11 @@ def __init__(
6161
self.hidden_sizes = configs.hidden_sizes
6262
self.input_size = configs.input_size
6363
self.output_size = configs.output_size
64-
self.hidden_activation = hidden_activation
64+
self.hidden_activation = (
65+
getattr(helper_functions, configs.hidden_activation)
66+
if "hidden_activation" in configs.keys()
67+
else hidden_activation
68+
)
6569
self.output_activation = getattr(helper_functions, configs.output_activation)
6670
self.linear_layer = linear_layer
6771
self.use_output_layer = use_output_layer

0 commit comments

Comments
 (0)