@@ -64,6 +64,7 @@ def __init__(
         act_size: List[int],
         reparameterize: bool = False,
         tanh_squash: bool = False,
+        condition_sigma: bool = True,
         log_sigma_min: float = -20,
         log_sigma_max: float = 2,
     ):
@@ -79,7 +80,11 @@ def __init__(
         :param log_sigma_max: Maximum log standard deviation to clip by.
         """
         encoded = self._create_mu_log_sigma(
-            logits, act_size, log_sigma_min, log_sigma_max
+            logits,
+            act_size,
+            log_sigma_min,
+            log_sigma_max,
+            condition_sigma=condition_sigma,
         )
         self._sampled_policy = self._create_sampled_policy(encoded)
         if not reparameterize:
@@ -101,6 +106,7 @@ def _create_mu_log_sigma(
         act_size: List[int],
         log_sigma_min: float,
         log_sigma_max: float,
+        condition_sigma: bool,
     ) -> "GaussianDistribution.MuSigmaTensors":

         mu = tf.layers.dense(
@@ -112,14 +118,22 @@ def _create_mu_log_sigma(
             reuse=tf.AUTO_REUSE,
         )

-        # Policy-dependent log_sigma_sq
-        log_sigma = tf.layers.dense(
-            logits,
-            act_size[0],
-            activation=None,
-            name="log_std",
-            kernel_initializer=ModelUtils.scaled_init(0.01),
-        )
+        if condition_sigma:
+            # Policy-dependent log_sigma
+            log_sigma = tf.layers.dense(
+                logits,
+                act_size[0],
+                activation=None,
+                name="log_std",
+                kernel_initializer=ModelUtils.scaled_init(0.01),
+            )
+        else:
+            log_sigma = tf.get_variable(
+                "log_std",
+                [act_size[0]],
+                dtype=tf.float32,
+                initializer=tf.zeros_initializer(),
+            )
         log_sigma = tf.clip_by_value(log_sigma, log_sigma_min, log_sigma_max)
         sigma = tf.exp(log_sigma)
         return self.MuSigmaTensors(mu, log_sigma, sigma)
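For intuition, here is a minimal NumPy sketch (illustrative only, not part of the change) of what the two branches above compute: with condition_sigma=True the log standard deviation is a per-state output of a dense layer, while with condition_sigma=False it is a single trainable vector shared across all states, initialized to zeros so that sigma starts at exp(0) = 1. All shapes and names below are assumptions for the example.

import numpy as np

batch, hidden, act_dim = 4, 8, 2                 # assumed example sizes
logits = np.random.randn(batch, hidden)          # stand-in encoder output

# condition_sigma=True: log-sigma depends on the state via a dense layer,
# one value per (state, action-dimension) pair.
w = np.random.randn(hidden, act_dim) * 0.01      # small init, as in the diff
conditioned_log_sigma = logits @ w               # shape (batch, act_dim)

# condition_sigma=False: log-sigma is one trainable vector shared across
# states; zeros initialization means sigma starts at exp(0) = 1.
shared_log_sigma = np.zeros(act_dim)             # shape (act_dim,)

# Either way, log-sigma is clipped before exponentiating, for stability.
log_sigma_min, log_sigma_max = -20.0, 2.0
sigma = np.exp(np.clip(shared_log_sigma, log_sigma_min, log_sigma_max))
print(conditioned_log_sigma.shape, sigma)        # (4, 2) [1. 1.]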
@@ -155,8 +169,8 @@ def _do_squash_correction_for_tanh(self, probs, squashed_policy):
         """
         Adjust probabilities for squashed sample before output
         """
-        probs -= tf.log(1 - squashed_policy ** 2 + EPSILON)
-        return probs
+        adjusted_probs = probs - tf.log(1 - squashed_policy ** 2 + EPSILON)
+        return adjusted_probs

     @property
     def total_log_probs(self) -> tf.Tensor:
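A note on the rewrite above: returning a fresh adjusted_probs tensor instead of rebinding probs in place makes it explicit that the method produces a corrected copy. The subtracted term is the standard tanh change-of-variables correction: for a = tanh(u), log p(a) = log p(u) - log(1 - tanh(u)^2 + epsilon). A quick standalone check of that Jacobian (illustrative only; the EPSILON value here is an assumed stand-in for the module-level constant):

import numpy as np

EPSILON = 1e-7  # assumed value for the module constant used in this file

# For a = tanh(u), da/du = 1 - tanh(u)**2, which is exactly the term whose
# log the squash correction subtracts from the Gaussian log-prob.
u = np.linspace(-3.0, 3.0, 7)
analytic = 1.0 - np.tanh(u) ** 2
numeric = (np.tanh(u + 1e-6) - np.tanh(u - 1e-6)) / 2e-6
print(np.allclose(analytic, numeric))  # True

# Corrected log-density of a = tanh(u) when u is standard normal:
log_p_u = -0.5 * u ** 2 - 0.5 * np.log(2.0 * np.pi)
log_p_a = log_p_u - np.log(1.0 - np.tanh(u) ** 2 + EPSILON)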