@@ -78,8 +78,10 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace,
78
78
self .rule_trace_edge = numba .typed .List .empty_list (numba .types .Tuple ((numba .types .uint16 , numba .types .uint16 , edge_type , label .label_type , interval .interval_type )))
79
79
80
80
# Nodes and edges of the graph
81
- self .nodes = numba .typed .List (self .graph .nodes ())
82
- self .edges = numba .typed .List (self .graph .edges ())
81
+ self .nodes = numba .typed .List .empty_list (node_type )
82
+ self .edges = numba .typed .List .empty_list (edge_type )
83
+ self .nodes .extend (numba .typed .List (self .graph .nodes ()))
84
+ self .edges .extend (numba .typed .List (self .graph .edges ()))
83
85
84
86
# Make sure they are correct type
85
87
if len (self .available_labels_node )== 0 :
@@ -91,8 +93,8 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace,
91
93
else :
92
94
self .available_labels_edge = numba .typed .List (self .available_labels_edge )
93
95
94
- self .interpretations_node = self ._init_interpretations_node (numba . typed . List ( self .graph . nodes ()) , self .available_labels_node , self .specific_node_labels )
95
- self .interpretations_edge = self ._init_interpretations_edge (numba . typed . List ( self .graph . edges ()) , self .available_labels_edge , self .specific_edge_labels )
96
+ self .interpretations_node = self ._init_interpretations_node (self .nodes , self .available_labels_node , self .specific_node_labels )
97
+ self .interpretations_edge = self ._init_interpretations_edge (self .edges , self .available_labels_edge , self .specific_edge_labels )
96
98
97
99
# Setup graph neighbors and reverse neighbors
98
100
self .neighbors = numba .typed .Dict .empty (key_type = node_type , value_type = numba .types .ListType (node_type ))
@@ -687,6 +689,10 @@ def delete_edge(self, edge):
687
689
# This function is useful for pyreason gym, called externally
688
690
_delete_edge (edge , self .neighbors , self .reverse_neighbors , self .edges , self .interpretations_edge )
689
691
692
+ def delete_node (self , node ):
693
+ # This function is useful for pyreason gym, called externally
694
+ _delete_node (node , self .neighbors , self .reverse_neighbors , self .nodes , self .interpretations_node )
695
+
690
696
def get_interpretation_dict (self ):
691
697
# This function can be called externally to retrieve a dict of the interpretation values
692
698
# Only values in the rule trace will be added
@@ -1928,6 +1934,22 @@ def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge
1928
1934
reverse_neighbors [target ].remove (source )
1929
1935
1930
1936
1937
+ @numba .njit (cache = True )
1938
+ def _delete_node (node , neighbors , reverse_neighbors , nodes , interpretations_node ):
1939
+ nodes .remove (node )
1940
+ del interpretations_node [node ]
1941
+ del neighbors [node ]
1942
+ del reverse_neighbors [node ]
1943
+
1944
+ # Remove all occurrences of node in neighbors
1945
+ for n in neighbors .keys ():
1946
+ if node in neighbors [n ]:
1947
+ neighbors [n ].remove (node )
1948
+ for n in reverse_neighbors .keys ():
1949
+ if node in reverse_neighbors [n ]:
1950
+ reverse_neighbors [n ].remove (node )
1951
+
1952
+
1931
1953
@numba .njit (cache = True )
1932
1954
def float_to_str (value ):
1933
1955
number = int (value )
0 commit comments