Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed build errors. #2

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions core/ctree/cnode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ namespace tree{

CRoots::~CRoots(){}

void CRoots::prepare(float root_exploration_fraction, const std::vector<std::vector<float>> &noises, const std::vector<float> &value_prefixs, const std::vector<std::vector<float>> &policies){
void CRoots::prepare(float root_exploration_fraction, const std::vector<std::vector<float> > &noises, const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies){
for(int i = 0; i < this->root_num; ++i){
this->roots[i].expand(0, 0, i, value_prefixs[i], policies[i]);
this->roots[i].add_exploration_noise(root_exploration_fraction, noises[i]);
Expand All @@ -210,7 +210,7 @@ namespace tree{
}
}

void CRoots::prepare_no_noise(const std::vector<float> &value_prefixs, const std::vector<std::vector<float>> &policies){
void CRoots::prepare_no_noise(const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies){
for(int i = 0; i < this->root_num; ++i){
this->roots[i].expand(0, 0, i, value_prefixs[i], policies[i]);

Expand All @@ -223,8 +223,8 @@ namespace tree{
this->roots.clear();
}

std::vector<std::vector<int>> CRoots::get_trajectories(){
std::vector<std::vector<int>> trajs;
std::vector<std::vector<int> > CRoots::get_trajectories(){
std::vector<std::vector<int> > trajs;
trajs.reserve(this->root_num);

for(int i = 0; i < this->root_num; ++i){
Expand All @@ -233,8 +233,8 @@ namespace tree{
return trajs;
}

std::vector<std::vector<int>> CRoots::get_distributions(){
std::vector<std::vector<int>> distributions;
std::vector<std::vector<int> > CRoots::get_distributions(){
std::vector<std::vector<int> > distributions;
distributions.reserve(this->root_num);

for(int i = 0; i < this->root_num; ++i){
Expand Down Expand Up @@ -314,7 +314,7 @@ namespace tree{
update_tree_q(root, min_max_stats, discount);
}

void cbatch_back_propagate(int hidden_state_index_x, float discount, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float>> &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_lst){
void cbatch_back_propagate(int hidden_state_index_x, float discount, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_lst){
for(int i = 0; i < results.num; ++i){
results.nodes[i]->expand(0, hidden_state_index_x, i, value_prefixs[i], policies[i]);
// reset
Expand Down
14 changes: 7 additions & 7 deletions core/ctree/cnode.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,17 @@ namespace tree {
public:
int root_num, action_num, pool_size;
std::vector<CNode> roots;
std::vector<std::vector<CNode>> node_pools;
std::vector<std::vector<CNode> > node_pools;

CRoots();
CRoots(int root_num, int action_num, int pool_size);
~CRoots();

void prepare(float root_exploration_fraction, const std::vector<std::vector<float>> &noises, const std::vector<float> &value_prefixs, const std::vector<std::vector<float>> &policies);
void prepare_no_noise(const std::vector<float> &value_prefixs, const std::vector<std::vector<float>> &policies);
void prepare(float root_exploration_fraction, const std::vector<std::vector<float> > &noises, const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies);
void prepare_no_noise(const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies);
void clear();
std::vector<std::vector<int>> get_trajectories();
std::vector<std::vector<int>> get_distributions();
std::vector<std::vector<int> > get_trajectories();
std::vector<std::vector<int> > get_distributions();
std::vector<float> get_values();

};
Expand All @@ -64,7 +64,7 @@ namespace tree {
int num;
std::vector<int> hidden_state_index_x_lst, hidden_state_index_y_lst, last_actions, search_lens;
std::vector<CNode*> nodes;
std::vector<std::vector<CNode*>> search_paths;
std::vector<std::vector<CNode*> > search_paths;

CSearchResults();
CSearchResults(int num);
Expand All @@ -76,7 +76,7 @@ namespace tree {
//*********************************************************
void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount);
void cback_propagate(std::vector<CNode*> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount);
void cbatch_back_propagate(int hidden_state_index_x, float discount, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float>> &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_lst);
void cbatch_back_propagate(int hidden_state_index_x, float discount, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_lst);
int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount, float mean_q);
float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount);
void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results);
Expand Down