-
Notifications
You must be signed in to change notification settings - Fork 4
/
nodes.h
99 lines (75 loc) · 2.51 KB
/
nodes.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#ifndef NODES_H_
#define NODES_H_
#include <iostream>
#include <vector>
using namespace std;
class MultiNode;
class ROOT{
public :
vector<double> edge_weights;
vector<MultiNode> children;
vector<int> maxind;
int leafstart;
vector<double> orig_edge_weights;
double edgesum;
double orig_edgesum;
ROOT();
ROOT(vector<double> edge_weights, vector<MultiNode> children, vector<int> maxind,
int leafstart, double edgesum, vector<double> orig_edge_weights, double orig_edgesum);
int num_leaves();
void sample_node();
void leaf_count_update(double val, int leaf);
double wordval_update(double val, int leaf);
double logphi_update();
vector<MultiNode> get_multinodes();
};
/**
 * Represents a node in the Dirichlet tree.
 *
 * A Node stores its child Nodes by value, so the structure is recursive.
 * All member functions are only declared here and are defined in another
 * translation unit.
 *
 * NOTE(review): the per-field remarks below are inferred from the member
 * names only and should be confirmed against the implementation.
 */
class Node {
public:
    // ---- data members (declaration order is significant for layout) ----
    std::vector<double> edge_weights;       // presumably one weight per outgoing edge - TODO confirm
    std::vector<Node> children;             // child nodes stored by value
    std::vector<int> maxind;                // presumably per-child leaf index bounds - TODO confirm
    int leafstart;
    std::vector<double> orig_edge_weights;  // presumably the initial copy of edge_weights - TODO confirm
    double edgesum;
    double orig_edgesum;
    std::vector<int> words;                 // only the second constructor takes a value for this

    // ---- construction ----
    /** Constructor without a `words` argument. */
    Node(std::vector<double> edge_weights,
         std::vector<Node> children,
         std::vector<int> maxind,
         int leafstart,
         double edgesum,
         std::vector<double> orig_edge_weights,
         double orig_edgesum);

    /** Constructor that additionally takes `words`. */
    Node(std::vector<double> edge_weights,
         std::vector<Node> children,
         std::vector<int> maxind,
         int leafstart,
         std::vector<int> words,
         double edgesum,
         std::vector<double> orig_edge_weights,
         double orig_edgesum);

    // ---- operations taking a (val, leaf) pair ----
    void leaf_count_update(double val, int leaf);
    double wordval_update(double val, int leaf);

    // ---- operations without arguments ----
    double logphi_update();
    void sample_node();
    int num_leaves();
};
class MultiNode{ //represent an intermediate node
public:
vector<double> edge_weights;
vector<int> maxind;
int leafstart;
vector<Node> children;
vector<double> orig_edge_weights;
double edgesum;
double orig_edgesum;
vector<int> words;
vector<Node> variants;
vector<double> variant_logweights;
vector<vector<int>> fake_leafmap;
int y;
MultiNode(vector<double> edge_weights, vector<Node> children, vector<int> maxind,
int leafstart, double edgesum, vector<double> orig_edge_weights, double orig_edgesum);
MultiNode(vector<double> edge_weights, vector<Node> children, vector<int> maxind,
int leafstart, vector<int> words, vector<Node> variants, vector<vector<int>> fake_leafmap,
vector<double> variant_logweights);
void leaf_count_update(double val, int leaf);
double wordval_update(double val, int leaf);
int num_variants();
double var_logweight(int given_y);
double logphi_update(int given_y);
int num_leaves();
double logphi_update();
};
#endif /* NODES_H_ */