Skip to content

Commit

Permalink
good use
Browse files: browse the repository at this point in the history
  • Loading branch information
rongkunxue committed Mar 28, 2024
1 parent 8ab5da8 commit b12714e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 7 deletions.
1 change: 0 additions & 1 deletion ding/model/template/qtransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def __init__(self, input_dim):
nn.ReLU(),
nn.Linear(256, 512)
)

def forward(self, x):
x = self.layers(x)
x = x.unsqueeze(1)
Expand Down
8 changes: 2 additions & 6 deletions ding/policy/qtransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
q_next = self._target_model.forward(next_obs)
# get target Q
q_target_all_actions = self._target_model.forward(states, actions = actions)

q_next = q_next.max(dim = -1).values
q_next.clamp_(min = -100)
q_target = q_target_all_actions.max(dim = -1).values
Expand All @@ -373,20 +374,15 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
q_intermediates = QIntermediates(q_pred_all_actions, q_pred, q_next, q_target)
num_timesteps = actions.shape[1]
batch = actions.shape[0]

q_preds = q_intermediates.q_pred_all_actions
q_preds = rearrange(q_preds, '... a -> (...) a')

num_action_bins = q_preds.shape[-1]
num_non_dataset_actions = num_action_bins - 1

actions = rearrange(actions, '... -> (...) 1')

dataset_action_mask = torch.zeros_like(q_preds).scatter_(-1, actions, torch.ones_like(q_preds))

q_actions_not_taken = q_preds[~dataset_action_mask.bool()]
q_actions_not_taken = rearrange(q_actions_not_taken, '(b t a) -> b t a', b = batch, a = num_non_dataset_actions)

conservative_reg_loss = ((q_actions_not_taken - (self._cfg.learn["min_reward"] * num_timesteps)) ** 2).sum() / num_non_dataset_actions
# total loss
loss_dict['loss']=0.5 * td_loss + 0.5 * conservative_reg_loss
Expand Down

0 comments on commit b12714e

Please sign in to comment.