Commit

Vision backbone improvements.

ViktorM committed Aug 19, 2024
1 parent 0dc5b43 commit eeae30e
Showing 4 changed files with 132 additions and 29 deletions.
1 change: 0 additions & 1 deletion rl_games/algos_torch/models.py
@@ -26,7 +26,6 @@ def get_value_layer(self):
 
     def build(self, config):
         obs_shape = config['input_shape']
-        print(f"obs_shape: {obs_shape}")
         normalize_value = config.get('normalize_value', False)
         normalize_input = config.get('normalize_input', False)
         value_size = config.get('value_size', 1)
61 changes: 35 additions & 26 deletions rl_games/algos_torch/network_builder.py
@@ -737,7 +737,6 @@ def forward(self, obs_dict):
             out = self.flatten_act(out)
 
             if self.has_rnn:
-                #seq_length = obs_dict['seq_length']
                 seq_length = obs_dict.get('seq_length', 1)
 
                 out_in = out
@@ -1103,12 +1102,11 @@ def __init__(self, params, **kwargs):
             if self.permute_input:
                 input_shape = torch_ext.shape_whc_to_cwh(input_shape)
 
-            self.cnn = self._build_backbone(input_shape, self.params['backbone'])
-            cnn_output_size = self.cnn_output_size
+            self.cnn, self.cnn_output_size = self._build_backbone(input_shape, params['backbone'])
 
-            mlp_input_size = cnn_output_size + self.proprio_size
+            mlp_input_size = self.cnn_output_size + self.proprio_size
             if len(self.units) == 0:
-                out_size = cnn_output_size
+                out_size = self.cnn_output_size
             else:
                 out_size = self.units[-1]

@@ -1153,9 +1151,6 @@ def __init__(self, params, **kwargs):
 
             mlp_init = self.init_factory.create(**self.initializer)
 
-            # for m in self.modules():
-            #     if isinstance(m, nn.Conv2d):
-            #         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
             for m in self.mlp:
                 if isinstance(m, nn.Linear):
                     mlp_init(m.weight)
@@ -1271,36 +1266,50 @@ def load(self, params):
             self.require_last_actions = params.get('require_last_actions')
 
         def _build_backbone(self, input_shape, backbone_params):
+            print("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
+            print(backbone_params)
             backbone_type = backbone_params['type']
             pretrained = backbone_params.get('pretrained', False)
 
             if backbone_type == 'resnet18':
-                model = models.resnet18(pretrained=pretrained)
+                backbone = models.resnet18(pretrained=pretrained, zero_init_residual=True) # norm_layer=nn.LayerNorm
                 # Modify the first convolution layer to match input shape if needed
-                if input_shape[0] != 3:
-                    model.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False)
+                backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False)
+                backbone.maxpool = nn.Identity()
+                # if input_shape[0] != 3:
+                #     model.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False)
                 # Remove the fully connected layer
-                self.cnn_output_size = model.fc.in_features
-                model = nn.Sequential(*list(model.children())[:-1])
+                backbone_output_size = backbone.fc.in_features
+                print('backbone_output_size: ', backbone_output_size)
+                backbone = nn.Sequential(*list(backbone.children())[:-1])
             elif backbone_type == 'convnext_tiny':
-                model = create_model('convnext_tiny', pretrained=pretrained)
+                backbone = create_model('convnext_tiny', pretrained=pretrained)
+                # Modify the first convolution layer to match input shape if needed
+                #backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False)
                 # Remove the fully connected layer
-                self.cnn_output_size = model.head.fc.in_features
-                model = nn.Sequential(*list(model.children())[:-1])
-            # elif backbone_type == 'vit_tiny_patch16_224':
-            #     model = create_model('vit_tiny_patch16_224', pretrained=pretrained)
-            #     # ViT outputs a single token, so no need to remove layers
-            #     self.cnn_output
+                backbone_output_size = backbone.head.fc.in_features
+
+                backbone = nn.Sequential(*list(backbone.children())[:-1])
+            elif backbone_type == 'vit_tiny_patch16_224':
+                backbone = create_model('vit_tiny_patch16_224', pretrained=pretrained)
+                # # ViT outputs a single token, so no need to remove layers
+                # backbone = models.vit_small_patch16_224(pretrained=pretrained)
+                backbone_output_size = backbone.heads.head.in_features
+                backbone.heads.head = nn.Identity()
             else:
                 raise ValueError(f'Unknown backbone type: {backbone_type}')
 
-            return model
+            # Optionally freeze the follow-up layers, leaving the first convolutional layer unfrozen
+            if backbone_params.get('freeze', False):
+                for name, param in backbone.named_parameters():
+                    if 'conv1' not in name: # Ensure the first conv layer is not frozen
+                        param.requires_grad = False
+
+            return backbone, backbone_output_size
 
     def build(self, name, **kwargs):
+        print("Building Network")
+        print(self.params)
        net = VisionBackboneBuilder.Network(self.params, **kwargs)
        return net
 
 
 class DiagGaussianActor(NetworkBuilder.BaseNetwork):
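The `resnet18` branch above follows a standard small-input recipe: swap the 7x7/stride-2 stem for a 3x3/stride-1 convolution, disable the max-pool so low-resolution frames keep spatial detail, strip the classifier, and return the feature size alongside the module. Here is a minimal standalone sketch of that pattern (illustrative only; `build_resnet18_backbone` is not rl_games API, and the 84x84 input is an assumed Atari-style shape):

```python
import torch
import torch.nn as nn
from torchvision import models

def build_resnet18_backbone(in_channels, pretrained=False):
    net = models.resnet18(pretrained=pretrained, zero_init_residual=True)
    # Small-input stem: 3x3/stride-1 conv plus no max-pool preserves spatial detail
    net.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.maxpool = nn.Identity()
    out_size = net.fc.in_features  # 512 for resnet18
    # Drop the final fc layer; avgpool remains, so the output is (N, 512, 1, 1)
    backbone = nn.Sequential(*list(net.children())[:-1])
    return backbone, out_size

backbone, out_size = build_resnet18_backbone(in_channels=4)
feats = backbone(torch.randn(2, 4, 84, 84)).flatten(1)  # -> shape (2, 512)
```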
97 changes: 97 additions & 0 deletions rl_games/configs/atari/ppo_pong_envpool_backbone.yaml
@@ -0,0 +1,97 @@
+params:
+  algo:
+    name: a2c_discrete
+
+  model:
+    name: discrete_a2c
+
+  network:
+    # name: resnet_actor_critic
+    # require_rewards: True
+    # require_last_actions: True
+    # separate: False
+    # value_shape: 1
+    name: e2e_vision_actor_critic
+    separate: False
+    value_shape: 1
+    space:
+      discrete:
+
+    # cnn:
+    #   permute_input: False
+    #   conv_depths: [16, 32, 32]
+    #   activation: relu
+    #   initializer:
+    #     name: default
+    #   regularizer:
+    #     name: 'None'
+
+    backbone:
+      type: resnet18 #vit_tiny_patch16_224 #convnext_tiny #resnet18
+      pretrained: False
+      permute_input: True
+      freeze: False
+
+      args:
+        zero_init_residual: False
+        norm_layer: None
+
+    mlp:
+      units: [512]
+      activation: relu
+      regularizer:
+        name: 'None'
+      initializer:
+        name: default
+
+    # rnn:
+    #   name: lstm
+    #   units: 256
+    #   layers: 1
+  config:
+    name: pong_resnet18_nopretrained_novaluenorm_weightdecay_nomaxpool
+    env_name: envpool
+    reward_shaper:
+      min_val: -1
+      max_val: 1
+
+    mixed_precision: False
+    normalize_input: False
+    normalize_value: False
+    normalize_advantage: True
+    gamma: 0.995
+    tau: 0.95
+    learning_rate: 1e-4
+
+    score_to_win: 100000
+    grad_norm: 1.5
+    entropy_coef: 0.01
+    truncate_grads: True
+
+    e_clip: 0.2
+    clip_value: True
+    save_best_after: 25
+    save_frequency: 50
+    num_actors: 64
+    horizon_length: 128
+    minibatch_size: 2048
+    mini_epochs: 4
+    critic_coef: 1
+    lr_schedule: None
+    kl_threshold: 0.01
+    use_diagnostics: True
+    seq_length: 8
+    max_epochs: 1000
+    #weight_decay: 0.001
+
+    env_config:
+      env_name: Pong-v5
+      has_lives: False
+      use_dict_obs_space: False #True
+
+    player:
+      render: True
+      games_num: 10
+      n_game_life: 1
+      deterministic: True

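Assuming the standard rl_games entry points, a config like this is typically launched through the `Runner`; a rough sketch (requires envpool for the `envpool` env, and the path is this commit's new file):

```python
import yaml
from rl_games.torch_runner import Runner

with open('rl_games/configs/atari/ppo_pong_envpool_backbone.yaml') as f:
    cfg = yaml.safe_load(f)

runner = Runner()
runner.load(cfg)                            # builds algo/model/network from 'params'
runner.run({'train': True, 'play': False})  # or {'play': True, 'checkpoint': ...} to evaluate
```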
2 changes: 0 additions & 2 deletions rl_games/networks/vision_networks.py
@@ -122,8 +122,6 @@ def forward(self, obs_dict):
             out = torch.cat([out, proprio], dim=1)
 
             if self.has_rnn:
-                # TODO: Double check, it's not lways present!!!
-                #seq_length = obs_dict['seq_length']
                 seq_length = obs_dict.get('seq_length', 1)
 
                 out_in = out
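Both this hunk and the matching one in `network_builder.py` replace direct indexing with a defaulted lookup because, as the removed TODO notes, `seq_length` is not always present: inference steps typically omit it, so `obs_dict['seq_length']` would raise `KeyError`. A two-line illustration of the fallback (inputs are made up):

```python
obs_dict = {'obs': [0.0]}                    # inference-style input: no 'seq_length' key
seq_length = obs_dict.get('seq_length', 1)   # defaults to single-step instead of raising KeyError
```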
