diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py
index e0812d70..89acc1ee 100644
--- a/rl_games/algos_torch/a2c_continuous.py
+++ b/rl_games/algos_torch/a2c_continuous.py
@@ -35,7 +35,7 @@ def __init__(self, base_name, params):
             'actions_num' : self.actions_num,
             'input_shape' : obs_shape,
             'num_seqs' : self.num_actors * self.num_agents,
-            'value_size': self.env_info.get('value_size',1),
+            'value_size': self.env_info.get('value_size', 1),
             'normalize_value' : self.normalize_value,
             'normalize_input': self.normalize_input,
         }
@@ -144,6 +144,14 @@ def calc_gradients(self, input_dict):
             'obs' : obs_batch,
         }
 
+        # print("TEST")
+        # print("----------------")
+        # for key in input_dict:
+        #     print(key)
+
+        # if "proprio" in input_dict:
+        #     batch_dict['proprio'] = input_dict['proprio']
+
         rnn_masks = None
         if self.is_rnn:
             rnn_masks = input_dict['rnn_masks']
diff --git a/rl_games/algos_torch/model_builder.py b/rl_games/algos_torch/model_builder.py
index c2045c5e..58378063 100644
--- a/rl_games/algos_torch/model_builder.py
+++ b/rl_games/algos_torch/model_builder.py
@@ -19,7 +19,10 @@ def __init__(self):
         self.network_factory.set_builders(NETWORK_REGISTRY)
         self.network_factory.register_builder('actor_critic', lambda **kwargs: network_builder.A2CBuilder())
         self.network_factory.register_builder('resnet_actor_critic',
-                                              lambda **kwargs: network_builder.A2CResnetBuilder())
+                                              lambda **kwargs: network_builder.A2CResnetBuilder())
+        self.network_factory.register_builder('vision_actor_critic',
+                                              lambda **kwargs: network_builder.A2CVisionBuilder())
+
         self.network_factory.register_builder('rnd_curiosity', lambda **kwargs: network_builder.RNDCuriosityBuilder())
         self.network_factory.register_builder('soft_actor_critic', lambda **kwargs: network_builder.SACBuilder())
 
diff --git a/rl_games/algos_torch/moving_mean_std.py b/rl_games/algos_torch/moving_mean_std.py
index 363da8f4..fa8c9360 100644
--- a/rl_games/algos_torch/moving_mean_std.py
+++ b/rl_games/algos_torch/moving_mean_std.py
@@ -76,7 +76,6 @@ def _get_stats(self):
         else:
             raise NotImplementedError(self.impl)
 
-
     def _update_stats(self, x):
         m = self.decay
         if self.impl == 'off':
diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py
index e5d625c0..699ec7db 100644
--- a/rl_games/algos_torch/network_builder.py
+++ b/rl_games/algos_torch/network_builder.py
@@ -263,8 +263,8 @@ def __init__(self, params, **kwargs):
 
             mlp_args = {
                 'input_size' : mlp_input_size,
-                'units' : self.units,
-                'activation' : self.activation,
+                'units' : self.units,
+                'activation' : self.activation,
                 'norm_func_name' : self.normalization,
                 'dense_func' : torch.nn.Linear,
                 'd2rl' : self.is_d2rl,
@@ -417,7 +417,7 @@ def forward(self, obs_dict):
             else:
                 out = obs
                 out = self.actor_cnn(out)
-                out = out.flatten(1)
+                out = out.flatten(1)
 
                 if self.has_rnn:
                     seq_length = obs_dict.get('seq_length', 1)
@@ -664,12 +664,12 @@ def __init__(self, params, **kwargs):
                     rnn_in_size += actions_num
 
                 self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers)
-                #self.layer_norm = torch.nn.LayerNorm(self.rnn_units)
+                self.layer_norm = torch.nn.LayerNorm(self.rnn_units)
 
             mlp_args = {
                 'input_size' : mlp_input_size,
-                'units' :self.units,
-                'activation' : self.activation,
+                'units' :self.units,
+                'activation' : self.activation,
                 'norm_func_name' : self.normalization,
                 'dense_func' : torch.nn.Linear
             }
@@ -701,7 +701,7 @@ def __init__(self, params, **kwargs):
                     nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                     #nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
             for m in self.mlp:
-                if isinstance(m, nn.Linear):
+                if isinstance(m, nn.Linear):
                     mlp_init(m.weight)
 
             if self.is_discrete:
@@ -733,9 +733,242 @@ def forward(self, obs_dict):
 
             out = obs
             out = self.cnn(out)
-            out = out.flatten(1)
+            out = out.flatten(1)
+            out = self.flatten_act(out)
+
+            if self.has_rnn:
+                #seq_length = obs_dict['seq_length']
+                seq_length = obs_dict.get('seq_length', 1)
+
+                out_in = out
+                if not self.is_rnn_before_mlp:
+                    out_in = out
+                    out = self.mlp(out)
+
+                obs_list = [out]
+                if self.require_rewards:
+                    obs_list.append(reward.unsqueeze(1))
+                if self.require_last_actions:
+                    obs_list.append(last_action)
+                out = torch.cat(obs_list, dim=1)
+                batch_size = out.size()[0]
+                num_seqs = batch_size // seq_length
+                out = out.reshape(num_seqs, seq_length, -1)
+
+                if len(states) == 1:
+                    states = states[0]
+
+                out = out.transpose(0, 1)
+                if dones is not None:
+                    dones = dones.reshape(num_seqs, seq_length, -1)
+                    dones = dones.transpose(0, 1)
+                out, states = self.rnn(out, states, dones, bptt_len)
+                out = out.transpose(0, 1)
+                out = out.contiguous().reshape(out.size()[0] * out.size()[1], -1)
+
+                if self.rnn_ln:
+                    out = self.layer_norm(out)
+                if self.is_rnn_before_mlp:
+                    out = self.mlp(out)
+                if type(states) is not tuple:
+                    states = (states,)
+            else:
+                out = self.mlp(out)
+
+            value = self.value_act(self.value(out))
+
+            if self.is_discrete:
+                logits = self.logits(out)
+                return logits, value, states
+
+            if self.is_continuous:
+                mu = self.mu_act(self.mu(out))
+                if self.fixed_sigma:
+                    sigma = self.sigma_act(self.sigma)
+                else:
+                    sigma = self.sigma_act(self.sigma(out))
+                return mu, mu*0 + sigma, value, states
+
+        def load(self, params):
+            self.separate = False
+            self.units = params['mlp']['units']
+            self.activation = params['mlp']['activation']
+            self.initializer = params['mlp']['initializer']
+            self.is_discrete = 'discrete' in params['space']
+            self.is_continuous = 'continuous' in params['space']
+            self.is_multi_discrete = 'multi_discrete' in params['space']
+            self.value_activation = params.get('value_activation', 'None')
+            self.normalization = params.get('normalization', None)
+
+            if self.is_continuous:
+                self.space_config = params['space']['continuous']
+                self.fixed_sigma = self.space_config['fixed_sigma']
+            elif self.is_discrete:
+                self.space_config = params['space']['discrete']
+            elif self.is_multi_discrete:
+                self.space_config = params['space']['multi_discrete']
+
+            self.has_rnn = 'rnn' in params
+            if self.has_rnn:
+                self.rnn_units = params['rnn']['units']
+                self.rnn_layers = params['rnn']['layers']
+                self.rnn_name = params['rnn']['name']
+                self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False)
+                self.rnn_ln = params['rnn'].get('layer_norm', False)
+
+            self.has_cnn = True
+            self.permute_input = params['cnn'].get('permute_input', True)
+            self.conv_depths = params['cnn']['conv_depths']
+            self.require_rewards = params.get('require_rewards')
+            self.require_last_actions = params.get('require_last_actions')
+
+        def _build_impala(self, input_shape, depths):
+            in_channels = input_shape[0]
+            layers = nn.ModuleList()
+            for d in depths:
+                layers.append(ImpalaSequential(in_channels, d))
+                in_channels = d
+            return nn.Sequential(*layers)
+
+        def is_separate_critic(self):
+            return False
+
+        def is_rnn(self):
+            return self.has_rnn
+
+        def get_default_rnn_state(self):
+            num_layers = self.rnn_layers
+            if self.rnn_name == 'lstm':
+                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),
+                        torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
+            else:
+                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
+
+        def build(self, name, **kwargs):
+            net = A2CResnetBuilder.Network(self.params, **kwargs)
+            return net
+
+
+class A2CVisionBuilder(NetworkBuilder):
+    def __init__(self, **kwargs):
+        NetworkBuilder.__init__(self)
+
+    def load(self, params):
+        self.params = params
+
+    class Network(NetworkBuilder.BaseNetwork):
+        def __init__(self, params, **kwargs):
+            self.actions_num = actions_num = kwargs.pop('actions_num')
+            input_shape = kwargs.pop('input_shape')
+            print('input_shape:', input_shape)
+            if type(input_shape) is dict:
+                input_shape = input_shape['observation']
+            self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1)
+            self.value_size = kwargs.pop('value_size', 1)
+
+            # TODO: add proprioception from config
+            # no normalization for proprioception for now
+            proprio_shape = kwargs.pop('proprio_shape', None)
+            self.proprio_size = 68
+
+            NetworkBuilder.BaseNetwork.__init__(self)
+            self.load(params)
+            if self.permute_input:
+                input_shape = torch_ext.shape_whc_to_cwh(input_shape)
+
+            self.cnn = self._build_impala(input_shape, self.conv_depths)
+            cnn_output_size = self._calc_input_size(input_shape, self.cnn)
+
+            mlp_input_size = cnn_output_size + self.proprio_size
+            if len(self.units) == 0:
+                out_size = cnn_output_size
+            else:
+                out_size = self.units[-1]
+
+            if self.has_rnn:
+                if not self.is_rnn_before_mlp:
+                    rnn_in_size = out_size
+                    out_size = self.rnn_units
+                else:
+                    rnn_in_size = mlp_input_size
+                    mlp_input_size = self.rnn_units
+
+                if self.require_rewards:
+                    rnn_in_size += 1
+                if self.require_last_actions:
+                    rnn_in_size += actions_num
+
+                self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers)
+                self.layer_norm = torch.nn.LayerNorm(self.rnn_units)
+
+            mlp_args = {
+                'input_size' : mlp_input_size,
+                'units' : self.units,
+                'activation' : self.activation,
+                'norm_func_name' : self.normalization,
+                'dense_func' : torch.nn.Linear
+            }
+
+            self.mlp = self._build_mlp(**mlp_args)
+
+            self.value = self._build_value_layer(out_size, self.value_size)
+            self.value_act = self.activations_factory.create(self.value_activation)
+            self.flatten_act = self.activations_factory.create(self.activation)
+
+            if self.is_discrete:
+                self.logits = torch.nn.Linear(out_size, actions_num)
+            if self.is_continuous:
+                self.mu = torch.nn.Linear(out_size, actions_num)
+                self.mu_act = self.activations_factory.create(self.space_config['mu_activation'])
+                mu_init = self.init_factory.create(**self.space_config['mu_init'])
+                self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation'])
+                sigma_init = self.init_factory.create(**self.space_config['sigma_init'])
+
+                if self.fixed_sigma:
+                    self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True)
+                else:
+                    self.sigma = torch.nn.Linear(out_size, actions_num)
+
+            mlp_init = self.init_factory.create(**self.initializer)
+
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                    #nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
+            for m in self.mlp:
+                if isinstance(m, nn.Linear):
+                    mlp_init(m.weight)
+
+            if self.is_discrete:
+                mlp_init(self.logits.weight)
+            if self.is_continuous:
+                mu_init(self.mu.weight)
+                if self.fixed_sigma:
+                    sigma_init(self.sigma)
+                else:
+                    sigma_init(self.sigma.weight)
+
+            mlp_init(self.value.weight)
+
+        def forward(self, obs_dict):
+            # for key in obs_dict:
+            #     print(key)
+            obs = obs_dict['obs']['camera']
+            proprio = obs_dict['obs']['proprio']
+            if self.permute_input:
+                obs = obs.permute((0, 3, 1, 2))
+
+            dones = obs_dict.get('dones', None)
+            bptt_len = obs_dict.get('bptt_len', 0)
+            states = obs_dict.get('rnn_states', None)
+
+            out = obs
+            out = self.cnn(out)
+            out = out.flatten(1)
             out = self.flatten_act(out)
+            out = torch.cat([out, proprio], dim=1)
+
             if self.has_rnn:
                 #seq_length = obs_dict['seq_length']
                 seq_length = obs_dict.get('seq_length', 1)
 
@@ -824,7 +1057,7 @@ def load(self, params):
 
         def _build_impala(self, input_shape, depths):
             in_channels = input_shape[0]
-            layers = nn.ModuleList()
+            layers = nn.ModuleList()
             for d in depths:
                 layers.append(ImpalaSequential(in_channels, d))
                 in_channels = d
@@ -845,7 +1078,7 @@ def get_default_rnn_state(self):
                 return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
 
         def build(self, name, **kwargs):
-            net = A2CResnetBuilder.Network(self.params, **kwargs)
+            net = A2CVisionBuilder.Network(self.params, **kwargs)
             return net
 
 
@@ -923,9 +1156,9 @@ def __init__(self, params, **kwargs):
             self.load(params)
 
             actor_mlp_args = {
-                'input_size' : obs_dim,
-                'units' : self.units,
-                'activation' : self.activation,
+                'input_size' : obs_dim,
+                'units' : self.units,
+                'activation' : self.activation,
                 'norm_func_name' : self.normalization,
                 'dense_func' : torch.nn.Linear,
                 'd2rl' : self.is_d2rl,
@@ -933,9 +1166,9 @@ def __init__(self, params, **kwargs):
 
             critic_mlp_args = {
-                'input_size' : obs_dim + action_dim,
-                'units' : self.units,
-                'activation' : self.activation,
+                'input_size' : obs_dim + action_dim,
+                'units' : self.units,
+                'activation' : self.activation,
                 'norm_func_name' : self.normalization,
                 'dense_func' : torch.nn.Linear,
                 'd2rl' : self.is_d2rl,
diff --git a/rl_games/algos_torch/torch_ext.py b/rl_games/algos_torch/torch_ext.py
index baa6deb6..60e13757 100644
--- a/rl_games/algos_torch/torch_ext.py
+++ b/rl_games/algos_torch/torch_ext.py
@@ -208,6 +208,7 @@ def get_coord(x):
                 x.size(0), 1, 1, 1).type_as(x)
             CoordConv2d.pool[key] = coord
         return CoordConv2d.pool[key]
+
     def forward(self, x):
         return torch.nn.functional.conv2d(torch.cat([x, self.get_coord(x).type_as(x)], 1), self.weight,
                                           self.bias, self.stride, self.padding, self.dilation, self.groups)
@@ -245,7 +246,6 @@ def forward(self, x):
         return self.gamma.expand_as(x) * (x - mean) / (std + self.eps) + self.beta.expand_as(x)
 
 
-
 class DiscreteActionsEncoder(nn.Module):
     def __init__(self, actions_max, mlp_out, emb_size, num_agents, use_embedding):
         super().__init__()
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index e0c0cb28..47cd6cd9 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -416,6 +416,9 @@ def get_action_values(self, obs):
             'rnn_states' : self.rnn_states
         }
 
+        # if 'proprio' in obs:
+        #     input_dict['proprio'] = obs['proprio']
+
         with torch.no_grad():
             res_dict = self.model(input_dict)
             if self.has_central_value:
@@ -449,6 +452,9 @@ def get_values(self, obs):
                     'obs' : processed_obs,
                     'rnn_states' : self.rnn_states
                 }
+                # if 'proprio' in obs:
+                #     input_dict['proprio'] = obs['proprio']
+
                 result = self.model(input_dict)
                 value = result['values']
                 return value
@@ -829,6 +835,8 @@ def play_steps_rnn(self):
             self.rnn_states = res_dict['rnn_states']
 
             self.experience_buffer.update_data('obses', n, self.obs['obs'])
+            # if 'proprio' in self.obs:
+            #     self.experience_buffer.update_data('proprio', n, self.obs['proprio'])
             self.experience_buffer.update_data('dones', n, self.dones.byte())
 
             for k in update_list:
@@ -1022,6 +1030,9 @@ def prepare_dataset(self, batch_dict):
         dataset_dict['rnn_states'] = rnn_states
         dataset_dict['rnn_masks'] = rnn_masks
 
+        # if 'proprio' in batch_dict:
+        #     dataset_dict['proprio'] = batch_dict['proprio']
+
         if self.use_action_masks:
             dataset_dict['action_masks'] = batch_dict['action_masks']
 
diff --git a/rl_games/common/experience.py b/rl_games/common/experience.py
index feea017c..7bef426e 100644
--- a/rl_games/common/experience.py
+++ b/rl_games/common/experience.py
@@ -325,10 +325,16 @@ def __init__(self, env_info, algo_info, device, aux_tensor_dict=None):
             self._init_from_aux_dict(self.aux_tensor_dict)
 
     def _init_from_env_info(self, env_info):
+        # TODO: Review and update for dictionary observation spaces
        obs_base_shape = self.obs_base_shape
         state_base_shape = self.state_base_shape
 
         self.tensor_dict['obses'] = self._create_tensor_from_space(env_info['observation_space'], obs_base_shape)
+        # print("obs base shape", obs_base_shape)
+        # print('obses shape:', self.tensor_dict['obses'].shape)
+        # print('proprioception_space shape:', env_info.get('proprioception_space'))
+        # if env_info.get('proprioception_space') is not None:
+        #     self.tensor_dict['proprio'] = self._create_tensor_from_space(env_info['proprioception_space'], self.obs_base_shape)
         if self.has_central_value:
             self.tensor_dict['states'] = self._create_tensor_from_space(env_info['state_space'], state_base_shape)
 
@@ -373,13 +379,16 @@ def _create_tensor_from_space(self, space, base_shape):
             return t_dict
 
     def update_data(self, name, index, val):
+        # print('name:', name)
+        # print(self.tensor_dict.keys())
+        # print(self.tensor_dict[name].shape)
+        # print(self.tensor_dict["obses"].shape)
         if type(val) is dict:
             for k,v in val.items():
                 self.tensor_dict[name][k][index,:] = v
         else:
             self.tensor_dict[name][index,:] = val
 
-
     def update_data_rnn(self, name, indices,play_mask, val):
         if type(val) is dict:
             for k,v in val:
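
Usage sketch (not part of the patch above): one way the newly registered 'vision_actor_critic' builder could be constructed directly, assuming an environment whose observation dict carries a 'camera' image plus the 68-dim 'proprio' vector that A2CVisionBuilder.Network currently hardcodes. The parameter keys mirror what Network.load() reads; every concrete value below is an illustrative placeholder, not a setting taken from this diff.

# Hypothetical construction sketch for the new A2CVisionBuilder; only the key
# names come from Network.load(), the values are made-up assumptions.
from rl_games.algos_torch import network_builder

params = {
    'space': {'continuous': {
        'mu_activation': 'None',
        'sigma_activation': 'None',
        'mu_init': {'name': 'default'},
        'sigma_init': {'name': 'const_initializer', 'val': 0},
        'fixed_sigma': True}},
    'mlp': {'units': [256, 128], 'activation': 'elu', 'initializer': {'name': 'default'}},
    'cnn': {'permute_input': True, 'conv_depths': [16, 32, 32]},
    # no 'rnn' key, so load() leaves has_rnn False
}

builder = network_builder.A2CVisionBuilder()
builder.load(params)

# input_shape is the camera observation in (H, W, C); forward() later expects
# obs_dict['obs'] to contain a 'camera' image and a 68-dim 'proprio' tensor.
net = builder.build('vision_net',
                    actions_num=8,
                    input_shape=(64, 64, 3),
                    num_seqs=4,
                    value_size=1)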