Skip to content

Commit

Permalink
add hooks in mlpextractor instead
Browse files — browse the repository at this point in the history
  • Loading branch information
taufeeque9 committed May 22, 2024
1 parent 72361e7 commit f355b45
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 16 deletions.
6 changes: 0 additions & 6 deletions stable_baselines3/common/recurrent/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
NatureCNN,
)
from stable_baselines3.common.type_aliases import Schedule, TorchGymObs, non_null
from transformer_lens.hook_points import HookPoint
from mamba_lens import InputDependentHookedRootModule


Expand Down Expand Up @@ -695,9 +694,6 @@ def __init__(
optimizer_kwargs=optimizer_kwargs,
)

self.hook_latent_pi = HookPoint()
self.hook_latent_vf = HookPoint()

# setup hook points
super().setup()

Expand Down Expand Up @@ -736,9 +732,7 @@ def forward( # type: ignore[override]
"""
latents, state = self._recurrent_extract_features(obs, state, episode_starts)
latent_pi = self.mlp_extractor.forward_actor(latents)
latent_pi = self.hook_latent_pi(latent_pi)
latent_vf = self.mlp_extractor.forward_critic(latents)
latent_vf = self.hook_latent_vf(latent_vf)

# Evaluate the values for the given observations
values = self.value_net(latent_vf)
Expand Down
24 changes: 14 additions & 10 deletions stable_baselines3/common/torch_layers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from collections import OrderedDict
from typing import Dict, List, Tuple, Type, Union

import gymnasium as gym
import torch as th
from gymnasium import spaces
from torch import nn
from transformer_lens.hook_points import HookPoint

from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
from stable_baselines3.common.type_aliases import TensorDict
Expand Down Expand Up @@ -182,8 +184,8 @@ def __init__(
) -> None:
super().__init__()
device = get_device(device)
policy_net: List[nn.Module] = []
value_net: List[nn.Module] = []
policy_net: OrderedDict[str, nn.Module] = OrderedDict()
value_net: OrderedDict[str, nn.Module] = OrderedDict()
last_layer_dim_pi = feature_dim
last_layer_dim_vf = feature_dim

Expand All @@ -195,14 +197,16 @@ def __init__(
else:
pi_layers_dims = vf_layers_dims = net_arch
# Iterate through the policy layers and build the policy net
for curr_layer_dim in pi_layers_dims:
policy_net.append(nn.Linear(last_layer_dim_pi, curr_layer_dim))
policy_net.append(activation_fn())
for i, curr_layer_dim in enumerate(pi_layers_dims):
policy_net[f"fc{i}"] = nn.Linear(last_layer_dim_pi, curr_layer_dim)
policy_net[f"activation{i}"] = activation_fn()
policy_net[f"hook_fc{i}"] = HookPoint()
last_layer_dim_pi = curr_layer_dim
# Iterate through the value layers and build the value net
for curr_layer_dim in vf_layers_dims:
value_net.append(nn.Linear(last_layer_dim_vf, curr_layer_dim))
value_net.append(activation_fn())
for i, curr_layer_dim in enumerate(vf_layers_dims):
value_net[f"fc{i}"] = nn.Linear(last_layer_dim_vf, curr_layer_dim)
value_net[f"activation{i}"] = activation_fn()
value_net[f"hook_fc{i}"] = HookPoint()
last_layer_dim_vf = curr_layer_dim

# Save dim, used to create the distributions
Expand All @@ -211,8 +215,8 @@ def __init__(

# Create networks
# If the list of layers is empty, the network will just act as an Identity module
self.policy_net = nn.Sequential(*policy_net).to(device)
self.value_net = nn.Sequential(*value_net).to(device)
self.policy_net = nn.Sequential(policy_net).to(device)
self.value_net = nn.Sequential(value_net).to(device)

def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
"""
Expand Down

0 comments on commit f355b45

Please sign in to comment.