From eeae30e408cb0c4936d0ab6ee29d24be1f41dae8 Mon Sep 17 00:00:00 2001
From: ViktorM
Date: Sun, 18 Aug 2024 19:51:04 -0700
Subject: [PATCH] Vision backbone improvements.

---
 rl_games/algos_torch/models.py           |  1 -
 rl_games/algos_torch/network_builder.py  | 61 +++++++-----
 .../atari/ppo_pong_envpool_backbone.yaml | 97 +++++++++++++++++++
 rl_games/networks/vision_networks.py     |  2 -
 4 files changed, 132 insertions(+), 29 deletions(-)
 create mode 100644 rl_games/configs/atari/ppo_pong_envpool_backbone.yaml

diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py
index fb518c9c..4a15183c 100644
--- a/rl_games/algos_torch/models.py
+++ b/rl_games/algos_torch/models.py
@@ -26,7 +26,6 @@ def get_value_layer(self):
 
     def build(self, config):
         obs_shape = config['input_shape']
-        print(f"obs_shape: {obs_shape}")
         normalize_value = config.get('normalize_value', False)
         normalize_input = config.get('normalize_input', False)
         value_size = config.get('value_size', 1)
diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py
index 29968f40..6fb00ff0 100644
--- a/rl_games/algos_torch/network_builder.py
+++ b/rl_games/algos_torch/network_builder.py
@@ -737,7 +737,6 @@ def forward(self, obs_dict):
             out = self.flatten_act(out)
 
         if self.has_rnn:
-            #seq_length = obs_dict['seq_length']
             seq_length = obs_dict.get('seq_length', 1)
 
             out_in = out
@@ -1103,12 +1102,11 @@ def __init__(self, params, **kwargs):
         if self.permute_input:
             input_shape = torch_ext.shape_whc_to_cwh(input_shape)
 
-        self.cnn = self._build_backbone(input_shape, self.params['backbone'])
-        cnn_output_size = self.cnn_output_size
+        self.cnn, self.cnn_output_size = self._build_backbone(input_shape, params['backbone'])
 
-        mlp_input_size = cnn_output_size + self.proprio_size
+        mlp_input_size = self.cnn_output_size + self.proprio_size
         if len(self.units) == 0:
-            out_size = cnn_output_size
+            out_size = self.cnn_output_size
         else:
             out_size = self.units[-1]
 
@@ -1153,9 +1151,6 @@ def __init__(self, params, **kwargs):
 
         mlp_init = self.init_factory.create(**self.initializer)
 
-        # for m in self.modules():
-        #     if isinstance(m, nn.Conv2d):
-        #         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
         for m in self.mlp:
             if isinstance(m, nn.Linear):
                 mlp_init(m.weight)
@@ -1271,36 +1266,50 @@ def load(self, params):
         self.require_last_actions = params.get('require_last_actions')
 
     def _build_backbone(self, input_shape, backbone_params):
-        print("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
-        print(backbone_params)
         backbone_type = backbone_params['type']
         pretrained = backbone_params.get('pretrained', False)
 
        if backbone_type == 'resnet18':
-            model = models.resnet18(pretrained=pretrained)
+            backbone = models.resnet18(pretrained=pretrained, zero_init_residual=True) # norm_layer=nn.LayerNorm
             # Modify the first convolution layer to match input shape if needed
-            if input_shape[0] != 3:
-                model.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False)
+            backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False)
+            backbone.maxpool = nn.Identity()
+            # if input_shape[0] != 3:
+            #     model.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False)
             # Remove the fully connected layer
-            self.cnn_output_size = model.fc.in_features
-            model = nn.Sequential(*list(model.children())[:-1])
+            backbone_output_size = backbone.fc.in_features
+            print('backbone_output_size: ', backbone_output_size)
+            backbone = nn.Sequential(*list(backbone.children())[:-1])
         elif backbone_type == 'convnext_tiny':
-            model = create_model('convnext_tiny', pretrained=pretrained)
+            backbone = create_model('convnext_tiny', pretrained=pretrained)
+            # Modify the first convolution layer to match input shape if needed
+            #backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False)
             # Remove the fully connected layer
-            self.cnn_output_size = model.head.fc.in_features
-            model = nn.Sequential(*list(model.children())[:-1])
-        # elif backbone_type == 'vit_tiny_patch16_224':
-        #     model = create_model('vit_tiny_patch16_224', pretrained=pretrained)
-        #     # ViT outputs a single token, so no need to remove layers
-        #     self.cnn_output
+            backbone_output_size = backbone.head.fc.in_features
+
+            backbone = nn.Sequential(*list(backbone.children())[:-1])
+        elif backbone_type == 'vit_tiny_patch16_224':
+            backbone = create_model('vit_tiny_patch16_224', pretrained=pretrained)
+            # # ViT outputs a single token, so no need to remove layers
+            # backbone = models.vit_small_patch16_224(pretrained=pretrained)
+            backbone_output_size = backbone.heads.head.in_features
+            backbone.heads.head = nn.Identity()
         else:
             raise ValueError(f'Unknown backbone type: {backbone_type}')
 
-        return model
+        # Optionally freeze the follow-up layers, leaving the first convolutional layer unfrozen
+        if backbone_params.get('freeze', False):
+            for name, param in backbone.named_parameters():
+                if 'conv1' not in name: # Ensure the first conv layer is not frozen
+                    param.requires_grad = False
 
-    def build(self, name, **kwargs):
-        net = VisionBackboneBuilder.Network(self.params, **kwargs)
-        return net
+        return backbone, backbone_output_size
+
+    def build(self, name, **kwargs):
+        print("Building Network")
+        print(self.params)
+        net = VisionBackboneBuilder.Network(self.params, **kwargs)
+        return net
 
 
 class DiagGaussianActor(NetworkBuilder.BaseNetwork):
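Note on the resnet18 changes above: replacing the stock 7x7/stride-2 conv1 and disabling the following max pool keeps small inputs at full resolution going into layer1, the usual adaptation of ImageNet ResNets to Atari-sized frames. A minimal standalone sketch of what the modified backbone computes, assuming torchvision is available; the 4-channel 84x84 input is illustrative of stacked grayscale frames:

import torch
import torch.nn as nn
from torchvision import models

# Same stem surgery as in _build_backbone above (no pretrained weights, matching the config).
backbone = models.resnet18(zero_init_residual=True)
backbone.conv1 = nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=False)
backbone.maxpool = nn.Identity()

backbone_output_size = backbone.fc.in_features              # 512 for resnet18
backbone = nn.Sequential(*list(backbone.children())[:-1])   # drop fc, keep the global avg pool

with torch.no_grad():
    feats = backbone(torch.zeros(1, 4, 84, 84))
print(backbone_output_size, feats.shape)  # 512 torch.Size([1, 512, 1, 1])

Because the adaptive average pool is kept, the flattened feature size equals fc.in_features at any input resolution, which is what lets _build_backbone report the MLP input size without a dummy forward pass.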
diff --git a/rl_games/configs/atari/ppo_pong_envpool_backbone.yaml b/rl_games/configs/atari/ppo_pong_envpool_backbone.yaml
new file mode 100644
index 00000000..6ae0f5c1
--- /dev/null
+++ b/rl_games/configs/atari/ppo_pong_envpool_backbone.yaml
@@ -0,0 +1,97 @@
+params:
+  algo:
+    name: a2c_discrete
+
+  model:
+    name: discrete_a2c
+
+  network:
+    # name: resnet_actor_critic
+    # require_rewards: True
+    # require_last_actions: True
+    # separate: False
+    # value_shape: 1
+    name: e2e_vision_actor_critic
+    separate: False
+    value_shape: 1
+    space:
+      discrete:
+
+    # cnn:
+    #   permute_input: False
+    #   conv_depths: [16, 32, 32]
+    #   activation: relu
+    #   initializer:
+    #     name: default
+    #   regularizer:
+    #     name: 'None'
+
+    backbone:
+      type: resnet18 #vit_tiny_patch16_224 #convnext_tiny #resnet18
+      pretrained: False
+      permute_input: True
+      freeze: False
+
+      args:
+        zero_init_residual: False
+        norm_layer: None
+
+    mlp:
+      units: [512]
+      activation: relu
+      regularizer:
+        name: 'None'
+      initializer:
+        name: default
+
+    # rnn:
+    #   name: lstm
+    #   units: 256
+    #   layers: 1
+
+  config:
+    name: pong_resnet18_nopretrained_novaluenorm_weightdecay_nomaxpool
+    env_name: envpool
+    reward_shaper:
+      min_val: -1
+      max_val: 1
+
+    mixed_precision: False
+    normalize_input: False
+    normalize_value: False
+    normalize_advantage: True
+    gamma: 0.995
+    tau: 0.95
+    learning_rate: 1e-4
+
+    score_to_win: 100000
+    grad_norm: 1.5
+    entropy_coef: 0.01
+    truncate_grads: True
+
+    e_clip: 0.2
+    clip_value: True
+    save_best_after: 25
+    save_frequency: 50
+    num_actors: 64
+    horizon_length: 128
+    minibatch_size: 2048
+    mini_epochs: 4
+    critic_coef: 1
+    lr_schedule: None
+    kl_threshold: 0.01
+    use_diagnostics: True
+    seq_length: 8
+    max_epochs: 1000
+    #weight_decay: 0.001
+
+    env_config:
+      env_name: Pong-v5
+      has_lives: False
+      use_dict_obs_space: False #True
+
+    player:
+      render: True
+      games_num: 10
+      n_game_life: 1
+      deterministic: True
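The new config can be launched through the standard rl_games runner; a minimal sketch, assuming envpool is installed (required by env_name: envpool) and that the run-args dict follows the pattern used in rl_games' existing examples:

import yaml
from rl_games.torch_runner import Runner

with open('rl_games/configs/atari/ppo_pong_envpool_backbone.yaml') as f:
    config = yaml.safe_load(f)

runner = Runner()
runner.load(config)
runner.run({'train': True, 'play': False})  # set 'play': True to evaluate with the player section

With num_actors: 64 and horizon_length: 128, each epoch collects 8192 transitions, so minibatch_size: 2048 yields four minibatches per mini-epoch.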
diff --git a/rl_games/networks/vision_networks.py b/rl_games/networks/vision_networks.py
index fc4a7772..ba98645f 100644
--- a/rl_games/networks/vision_networks.py
+++ b/rl_games/networks/vision_networks.py
@@ -122,8 +122,6 @@ def forward(self, obs_dict):
         out = torch.cat([out, proprio], dim=1)
 
         if self.has_rnn:
-            # TODO: Double check, it's not lways present!!!
-            #seq_length = obs_dict['seq_length']
             seq_length = obs_dict.get('seq_length', 1)
 
             out_in = out
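The seq_length changes in both forward() methods above swap direct indexing for a defaulted lookup because, as the deleted TODO notes, obs_dict only carries 'seq_length' when the trainer feeds stored RNN sequences; during rollouts and evaluation the key is absent and obs_dict['seq_length'] would raise a KeyError. A small sketch of the resulting behavior (rnn_reshape is a hypothetical helper that mirrors the pattern, for illustration only):

import torch

def rnn_reshape(out, obs_dict):
    # Absent key -> 1, i.e. treat every step as its own length-1 sequence.
    seq_length = obs_dict.get('seq_length', 1)
    batch = out.size(0) // seq_length
    return out.reshape(batch, seq_length, -1)

out = torch.zeros(8, 512)
print(rnn_reshape(out, {}).shape)                 # torch.Size([8, 1, 512])
print(rnn_reshape(out, {'seq_length': 4}).shape)  # torch.Size([2, 4, 512])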