Skip to content

Update visualization #50

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 9, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 67 additions & 32 deletions bayesml/contexttree/_contexttree.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ def _gen_params_recursion(self,node:_Node,h_node:_Node):
else:
node.h_g = self.h_g
node.h_beta_vec[:] = self.h_beta_vec
if node.depth == self.c_d_max or self.rng.random() > self.h_g: # 葉ノード
if node.depth == self.c_d_max or self.rng.random() > self.h_g: # leaf node
node.theta_vec[:] = self.rng.dirichlet(self.h_beta_vec)
node.leaf = True
else: # 内部ノード
else: # inner node
node.leaf = False
for i in range(self.c_k):
if node.children[i] is None:
Expand All @@ -122,10 +122,10 @@ def _gen_params_recursion(self,node:_Node,h_node:_Node):
else:
node.h_g = h_node.h_g
node.h_beta_vec[:] = h_node.h_beta_vec
if node.depth == self.c_d_max or self.rng.random() > h_node.h_g: # 葉ノード
if node.depth == self.c_d_max or self.rng.random() > h_node.h_g: # leaf node
node.theta_vec[:] = self.rng.dirichlet(h_node.h_beta_vec)
node.leaf = True
else: # 内部ノード
else: # inner node
node.leaf = False
for i in range(self.c_k):
if node.children[i] is None:
Expand All @@ -140,9 +140,9 @@ def _gen_params_recursion_tree_fix(self,node:_Node,h_node:_Node):
else:
node.h_g = self.h_g
node.h_beta_vec[:] = self.h_beta_vec
if node.leaf: # 葉ノード
if node.leaf: # leaf node
node.theta_vec[:] = self.rng.dirichlet(self.h_beta_vec)
else: # 内部ノード
else: # inner node
for i in range(self.c_k):
if node.children[i] is not None:
self._gen_params_recursion_tree_fix(node.children[i],None)
Expand All @@ -152,9 +152,9 @@ def _gen_params_recursion_tree_fix(self,node:_Node,h_node:_Node):
else:
node.h_g = h_node.h_g
node.h_beta_vec[:] = h_node.h_beta_vec
if node.leaf: # 葉ノード
if node.leaf: # leaf node
node.theta_vec[:] = self.rng.dirichlet(h_node.h_beta_vec)
else: # 内部ノード
else: # inner node
for i in range(self.c_k):
if node.children[i] is not None:
self._gen_params_recursion_tree_fix(node.children[i],h_node.children[i])
Expand All @@ -170,7 +170,7 @@ def _set_params_recursion(self,node:_Node,original_tree_node:_Node):
a object from _Node class
"""
node.theta_vec[:] = original_tree_node.theta_vec
if original_tree_node.leaf or node.depth == self.c_d_max: # 葉ノード
if original_tree_node.leaf or node.depth == self.c_d_max: # leaf node
node.leaf = True
else:
node.leaf = False
Expand Down Expand Up @@ -201,7 +201,7 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
else:
node.h_g = original_tree_node.h_g
node.h_beta_vec[:] = original_tree_node.h_beta_vec
if original_tree_node.leaf or node.depth == self.c_d_max: # 葉ノード
if original_tree_node.leaf or node.depth == self.c_d_max: # leaf node
node.leaf = True
if node.depth == self.c_d_max:
node.h_g = 0
Expand All @@ -223,7 +223,7 @@ def _gen_sample_recursion(self,node,x):
x : numpy ndarray
1 dimensional array whose elements are 0 or 1.
"""
if node.leaf: # 葉ノード
if node.leaf: # leaf node
return self.rng.choice(self.c_k,p=node.theta_vec)
else:
return self._gen_sample_recursion(node.children[x[-node.depth-1]],x)
Expand Down Expand Up @@ -433,15 +433,14 @@ def visualize_model(self,filename=None,format=None,sample_length=10):
--------
graphviz.Digraph
"""
#例外処理
_check.pos_int(sample_length,'sample_length',DataFormatError)

try:
import graphviz
tree_graph = graphviz.Digraph(filename=filename,format=format)
tree_graph.attr("node",shape="box",fontname="helvetica",style="rounded,filled")
self._visualize_model_recursion(tree_graph, self.root, 0, None, None, 1.0)
# コンソール上で表示できるようにした方がいいかもしれない.
# Can we show the image on the console without saving the file?
tree_graph.view()
except ImportError as e:
print(e)
Expand Down Expand Up @@ -532,7 +531,7 @@ def _set_h0_params_recursion(self,node:_Node,original_tree_node:_Node):
else:
node.h_g = original_tree_node.h_g
node.h_beta_vec[:] = original_tree_node.h_beta_vec
if original_tree_node.leaf or node.depth == self.c_d_max: # 葉ノード
if original_tree_node.leaf or node.depth == self.c_d_max: # leaf node
node.leaf = True
if node.depth == self.c_d_max:
node.h_g = 0
Expand Down Expand Up @@ -565,7 +564,7 @@ def _set_hn_params_recursion(self,node:_Node,original_tree_node:_Node):
else:
node.h_g = original_tree_node.h_g
node.h_beta_vec[:] = original_tree_node.h_beta_vec
if original_tree_node.leaf or node.depth == self.c_d_max: # 葉ノード
if original_tree_node.leaf or node.depth == self.c_d_max: # leaf node
node.leaf = True
if node.depth == self.c_d_max:
node.h_g = 0
Expand Down Expand Up @@ -610,7 +609,8 @@ def set_h0_params(self,
raise(ParameterFormatError(
"h0_root must be an instance of contexttree._Node"
))
self.h0_root = _Node(0,self.c_k)
if self.h0_root is None:
self.h0_root = _Node(0,self.c_k)
self._set_h0_params_recursion(self.h0_root,h0_root)

self.reset_hn_params()
Expand Down Expand Up @@ -663,7 +663,8 @@ def set_hn_params(self,
raise(ParameterFormatError(
"hn_root must be an instance of contexttree._Node"
))
self.hn_root = _Node(0,self.c_k)
if self.hn_root is None:
self.hn_root = _Node(0,self.c_k)
self._set_hn_params_recursion(self.hn_root,hn_root)

self.calc_pred_dist(np.zeros(self.c_d_max,dtype=int))
Expand All @@ -688,7 +689,7 @@ def _update_posterior_leaf(self,node:_Node,x,i):
return tmp

def _update_posterior_recursion(self,node:_Node,x,i):
if node.depth < self.c_d_max and i-1-node.depth >= 0: # 内部ノード
if node.depth < self.c_d_max and i-1-node.depth >= 0: # inner node
if node.children[x[i-node.depth-1]] is None:
node.children[x[i-node.depth-1]] = _Node(
node.depth+1,
Expand All @@ -703,7 +704,7 @@ def _update_posterior_recursion(self,node:_Node,x,i):
tmp2 = (1 - node.h_g) * self._update_posterior_leaf(node,x,i) + node.h_g * tmp1
node.h_g = node.h_g * tmp1 / tmp2
return tmp2
else: # 葉ノード
else: # leaf node
return self._update_posterior_leaf(node,x,i)

def update_posterior(self,x):
Expand All @@ -728,11 +729,11 @@ def update_posterior(self,x):
self._update_posterior_recursion(self.hn_root,x,i)

def _map_recursion_add_nodes(self,node:_Node):
if node.depth == self.c_d_max: # 葉ノード
if node.depth == self.c_d_max: # leaf node
node.h_g = 0.0
node.leaf = True
node.map_leaf = True
else: # 内部ノード
else: # inner node
for i in range(self.c_k):
node.children[i] = _Node(node.depth+1,self.c_k)
node.children[i].h_g = self.hn_g
Expand Down Expand Up @@ -817,6 +818,7 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
tree_graph = graphviz.Digraph(filename=filename,format=format)
tree_graph.attr("node",shape="box",fontname="helvetica",style="rounded,filled")
self._visualize_model_recursion(tree_graph, map_root, 0, None, None, 1.0)
# Can we show the image on the console without saving the file?
tree_graph.view()
return map_root
else:
Expand Down Expand Up @@ -851,6 +853,38 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibl

return node_id

def _visualize_model_recursion_none(self,tree_graph,depth,node_id,parent_id,sibling_num,p_v):
tmp_id = node_id
tmp_p_v = p_v

# add node information
if depth == self.c_d_max:
label_string = 'hn_g=0\\l'
else:
label_string = f'hn_g={self.hn_g:.2f}\\l'
label_string += f'p_v={tmp_p_v:.2f}\\ltheta_vec\\l='
label_string += '['
for i in range(self.c_k):
theta_vec_hat = self.hn_beta_vec / self.hn_beta_vec.sum()
label_string += f'{theta_vec_hat[i]:.2f}'
if i < self.c_k-1:
label_string += ','
label_string += ']'

tree_graph.node(name=f'{tmp_id}',label=label_string,fillcolor=f'{rgb2hex(_CMAP(tmp_p_v))}')
if tmp_p_v > 0.65:
tree_graph.node(name=f'{tmp_id}',fontcolor='white')

# add edge information
if parent_id is not None:
tree_graph.edge(f'{parent_id}', f'{tmp_id}', label=f'{sibling_num}')

if depth < self.c_d_max:
for i in range(self.c_k):
node_id = self._visualize_model_recursion_none(tree_graph,depth+1,node_id+1,tmp_id,i,tmp_p_v*self.hn_g)

return node_id

def visualize_posterior(self,filename=None,format=None):
"""Visualize the posterior distribution for the parameter.

Expand Down Expand Up @@ -883,8 +917,11 @@ def visualize_posterior(self,filename=None,format=None):
import graphviz
tree_graph = graphviz.Digraph(filename=filename,format=format)
tree_graph.attr("node",shape="box",fontname="helvetica",style="rounded,filled")
self._visualize_model_recursion(tree_graph, self.hn_root, 0, None, None, 1.0)
# コンソール上で表示できるようにした方がいいかもしれない.
if self.hn_root is None:
self._visualize_model_recursion_none(tree_graph, 0, 0, None, None, 1.0)
else:
self._visualize_model_recursion(tree_graph, self.hn_root, 0, None, None, 1.0)
# Can we show the image on the console without saving the file?
tree_graph.view()
except ImportError as e:
print(e)
Expand All @@ -905,7 +942,7 @@ def _calc_pred_dist_leaf(self,node:_Node):
return node.h_beta_vec / node.h_beta_vec.sum()

def _calc_pred_dist_recursion(self,node:_Node,x,i):
if node.depth < self.c_d_max and i-1-node.depth >= 0: # 内部ノード
if node.depth < self.c_d_max and i-1-node.depth >= 0: # inner node
if node.children[x[i-node.depth-1]] is None:
node.children[x[i-node.depth-1]] = _Node(
node.depth+1,
Expand All @@ -919,7 +956,7 @@ def _calc_pred_dist_recursion(self,node:_Node,x,i):
tmp1 = self._calc_pred_dist_recursion(node.children[x[i-node.depth-1]],x,i)
tmp2 = (1 - node.h_g) * self._calc_pred_dist_leaf(node) + node.h_g * tmp1
return tmp2
else: # 葉ノード
else: # leaf node
return self._calc_pred_dist_leaf(node)

def calc_pred_dist(self,x):
Expand All @@ -936,11 +973,9 @@ def calc_pred_dist(self,x):
i = x.shape[0] - 1

if self.hn_root is None:
self.hn_root = _Node(0,self.c_k)
self.hn_root.h_g = self.hn_g
self.hn_root.h_beta_vec[:] = self.hn_beta_vec

self.p_theta_vec[:] = self._calc_pred_dist_recursion(self.hn_root,x,i)
self.p_theta_vec[:] = self.hn_beta_vec / self.hn_beta_vec.sum()
else:
self.p_theta_vec[:] = self._calc_pred_dist_recursion(self.hn_root,x,i)

def make_prediction(self,loss="KL"):
"""Predict a new data point under the given criterion.
Expand Down Expand Up @@ -973,7 +1008,7 @@ def _pred_and_update_leaf(self,node:_Node,x,i):
return tmp

def _pred_and_update_recursion(self,node:_Node,x,i):
if node.depth < self.c_d_max and i-1-node.depth >= 0: # 内部ノード
if node.depth < self.c_d_max and i-1-node.depth >= 0: # inner node
if node.children[x[i-node.depth-1]] is None:
node.children[x[i-node.depth-1]] = _Node(
node.depth+1,
Expand All @@ -988,7 +1023,7 @@ def _pred_and_update_recursion(self,node:_Node,x,i):
tmp2 = (1 - node.h_g) * self._pred_and_update_leaf(node,x,i) + node.h_g * tmp1
node.h_g = node.h_g * tmp1[x[i]] / tmp2[x[i]]
return tmp2
else: # 葉ノード
else: # leaf node
return self._pred_and_update_leaf(node,x,i)

def pred_and_update(self,x,loss="KL"):
Expand Down