Skip to content
This repository was archived by the owner on Dec 28, 2023. It is now read-only.

Commit 348f705

Browse files
author
Omegastick
committed
Add setters to RolloutStorage
1 parent ae03039 commit 348f705

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

include/cpprl/storage.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,16 @@ class RolloutStorage
5959
{
6060
return value_predictions;
6161
}
62+
inline void set_actions(torch::Tensor actions) { this->actions = actions; }
63+
inline void set_action_log_probs(torch::Tensor action_log_probs) { this->action_log_probs = action_log_probs; }
64+
inline void set_hidden_states(torch::Tensor hidden_states) { this->hidden_states = hidden_states; }
65+
inline void set_masks(torch::Tensor masks) { this->masks = masks; }
66+
inline void set_observations(torch::Tensor observations) { this->observations = observations; }
67+
inline void set_returns(torch::Tensor returns) { this->returns = returns; }
68+
inline void set_rewards(torch::Tensor rewards) { this->rewards = rewards; }
69+
inline void set_value_predictions(torch::Tensor value_predictions)
70+
{
71+
this->value_predictions = value_predictions;
72+
}
6273
};
6374
}

0 commit comments

Comments
 (0)