diff --git a/ding/model/template/qtransformer.py b/ding/model/template/qtransformer.py
index c365ec54aa..e33010a165 100644
--- a/ding/model/template/qtransformer.py
+++ b/ding/model/template/qtransformer.py
@@ -33,7 +33,6 @@ def __init__(self, input_dim):
             nn.ReLU(),
             nn.Linear(256, 512)
         )
-
     def forward(self, x):
         x = self.layers(x)
         x = x.unsqueeze(1)
diff --git a/ding/policy/qtransformer.py b/ding/policy/qtransformer.py
index 735f8cb0a8..5a2cfaefe0 100644
--- a/ding/policy/qtransformer.py
+++ b/ding/policy/qtransformer.py
@@ -350,6 +350,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
         q_next = self._target_model.forward(next_obs)
         # get target Q
         q_target_all_actions = self._target_model.forward(states, actions = actions)
+        q_next = q_next.max(dim = -1).values
         q_next.clamp_(min = -100)
         q_target = q_target_all_actions.max(dim = -1).values
@@ -373,20 +374,15 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
         q_intermediates = QIntermediates(q_pred_all_actions, q_pred, q_next, q_target)
         num_timesteps = actions.shape[1]
         batch = actions.shape[0]
-
+
         q_preds = q_intermediates.q_pred_all_actions
         q_preds = rearrange(q_preds, '... a -> (...) a')
-
         num_action_bins = q_preds.shape[-1]
         num_non_dataset_actions = num_action_bins - 1
-
         actions = rearrange(actions, '... -> (...) 1')
-
         dataset_action_mask = torch.zeros_like(q_preds).scatter_(-1, actions, torch.ones_like(q_preds))
-
         q_actions_not_taken = q_preds[~dataset_action_mask.bool()]
         q_actions_not_taken = rearrange(q_actions_not_taken, '(b t a) -> b t a', b = batch, a = num_non_dataset_actions)
-
         conservative_reg_loss = ((q_actions_not_taken - (self._cfg.learn["min_reward"] * num_timesteps)) ** 2).sum() / num_non_dataset_actions
         # total loss
         loss_dict['loss']=0.5 * td_loss + 0.5 * conservative_reg_loss
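
Note: the second hunk above builds the conservative regularization term by masking out the dataset action in each row of Q predictions and pushing the remaining bins toward a floor value. Below is a minimal, standalone sketch of that masking logic; the tensor shapes and the min_reward placeholder are illustrative assumptions, not values from the actual config.

# Standalone sketch of the conservative-regularization masking in the hunk above.
# Shapes and min_reward are assumed placeholders for illustration only.
import torch
from einops import rearrange

batch, num_timesteps, num_action_bins = 4, 3, 8
min_reward = 0.0  # assumed stand-in for self._cfg.learn["min_reward"]

# Flattened per-bin Q predictions: (batch * timesteps, num_action_bins)
q_preds = torch.randn(batch * num_timesteps, num_action_bins)
# Dataset actions, flattened to one column index per row: (batch * timesteps, 1)
actions = torch.randint(0, num_action_bins, (batch * num_timesteps, 1))

# One-hot mask marking the bin actually taken in the dataset for each row.
dataset_action_mask = torch.zeros_like(q_preds).scatter_(-1, actions, torch.ones_like(q_preds))

# Keep only the Q-values of bins *not* taken, then restore (batch, time, bins - 1).
num_non_dataset_actions = num_action_bins - 1
q_actions_not_taken = q_preds[~dataset_action_mask.bool()]
q_actions_not_taken = rearrange(q_actions_not_taken, '(b t a) -> b t a', b=batch, a=num_non_dataset_actions)

# Push untaken bins toward min_reward * num_timesteps, mirroring the loss in the diff.
conservative_reg_loss = ((q_actions_not_taken - min_reward * num_timesteps) ** 2).sum() / num_non_dataset_actions
print(q_actions_not_taken.shape, conservative_reg_loss.item())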