+import math
 from typing import Dict, Tuple
 
 import gymnasium as gym
-import numpy as np
 import torch
 import torch.nn.functional as F
 from rl.loss import entropy_loss, policy_loss, value_loss
@@ -24,7 +24,8 @@ def __init__(self, envs: gym.vector.SyncVectorEnv, act_fun: str = "relu", ortho_
             raise ValueError("Unrecognized activation function: `act_fun` must be either `relu` or `tanh`")
         self.critic = torch.nn.Sequential(
             layer_init(
-                torch.nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64), ortho_init=ortho_init
+                torch.nn.Linear(math.prod(envs.single_observation_space.shape), 64),
+                ortho_init=ortho_init,
             ),
             act_fun,
             layer_init(torch.nn.Linear(64, 64), ortho_init=ortho_init),
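The hunk above (and the identical hunks below) replaces the NumPy-based flattening of the observation shape with the standard-library math.prod, which is what lets the numpy import be dropped. A minimal sketch of why the two expressions agree — the (4,) shape is only an illustrative stand-in for envs.single_observation_space.shape, not taken from the patch:

import math

import numpy as np

shape = (4,)  # example 1-D observation shape; any tuple of ints behaves the same way
old_size = int(np.array(shape).prod())  # old expression: build an array, multiply its entries
new_size = math.prod(shape)             # new expression: multiply the tuple entries directly
assert old_size == new_size == 4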
@@ -33,7 +34,8 @@ def __init__(self, envs: gym.vector.SyncVectorEnv, act_fun: str = "relu", ortho_
         )
         self.actor = torch.nn.Sequential(
             layer_init(
-                torch.nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64), ortho_init=ortho_init
+                torch.nn.Linear(math.prod(envs.single_observation_space.shape), 64),
+                ortho_init=ortho_init,
             ),
             act_fun,
             layer_init(torch.nn.Linear(64, 64), ortho_init=ortho_init),
@@ -81,10 +83,10 @@ def estimate_returns_and_advantages(
         lastgaelam = 0
         for t in reversed(range(num_steps)):
             if t == num_steps - 1:
-                nextnonterminal = 1.0 - next_done
+                nextnonterminal = torch.logical_not(next_done)
                 nextvalues = next_value
             else:
-                nextnonterminal = 1.0 - dones[t + 1]
+                nextnonterminal = torch.logical_not(dones[t + 1])
                 nextvalues = values[t + 1]
             delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
             advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
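The swap from 1.0 - done to torch.logical_not(done) in this hunk (and in the matching hunk further down) is what you want when the done flags live in boolean tensors — an assumption about this repo's rollout buffers, not something visible in the patch itself: PyTorch rejects the `-` operator on bool tensors, whereas logical_not inverts the mask directly and the result is promoted back to float when multiplied into the bootstrap terms. A small sketch with made-up values:

import torch

next_done = torch.tensor([True, False])   # boolean done flags for two envs
old_mask = 1.0 - next_done.float()        # old form; needs an explicit cast when dones are bool
new_mask = torch.logical_not(next_done)   # new form: invert the boolean mask directly
nextvalues = torch.tensor([1.5, 2.0])
# The bool mask is promoted to float in arithmetic, so both forms mask the bootstrap identically.
assert torch.equal(nextvalues * old_mask, nextvalues * new_mask)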
@@ -119,7 +121,8 @@ def __init__(
         self.normalize_advantages = normalize_advantages
         self.critic = torch.nn.Sequential(
             layer_init(
-                torch.nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64), ortho_init=ortho_init
+                torch.nn.Linear(math.prod(envs.single_observation_space.shape), 64),
+                ortho_init=ortho_init,
             ),
             act_fun,
             layer_init(torch.nn.Linear(64, 64), ortho_init=ortho_init),
@@ -128,7 +131,8 @@ def __init__(
         )
         self.actor = torch.nn.Sequential(
             layer_init(
-                torch.nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64), ortho_init=ortho_init
+                torch.nn.Linear(math.prod(envs.single_observation_space.shape), 64),
+                ortho_init=ortho_init,
             ),
             act_fun,
             layer_init(torch.nn.Linear(64, 64), ortho_init=ortho_init),
@@ -179,10 +183,10 @@ def estimate_returns_and_advantages(
         lastgaelam = 0
         for t in reversed(range(num_steps)):
             if t == num_steps - 1:
-                nextnonterminal = 1.0 - next_done
+                nextnonterminal = torch.logical_not(next_done)
                 nextvalues = next_value
             else:
-                nextnonterminal = 1.0 - dones[t + 1]
+                nextnonterminal = torch.logical_not(dones[t + 1])
                 nextvalues = values[t + 1]
             delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
             advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
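For reference, the loop in both estimate_returns_and_advantages hunks computes the usual GAE recursion; written out with the names from the diff, and with nextnonterminal spelled as (1 - done), it is:

    delta[t]      = rewards[t] + gamma * values[t + 1] * (1 - dones[t + 1]) - values[t]
    advantages[t] = delta[t] + gamma * gae_lambda * (1 - dones[t + 1]) * advantages[t + 1]

with the final step (t == num_steps - 1) bootstrapping from next_value and next_done instead of values[t + 1] and dones[t + 1].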