@@ -57,6 +57,22 @@ def __init__(self, number_of_clauses, T, s, boost_true_positive_feedback=1, numb
57
57
self .prepare_encode = mod_encode .get_function ("prepare_encode" )
58
58
self .encode = mod_encode .get_function ("encode" )
59
59
60
+ def __getstate__ (self ):
61
+ state = self .__dict__ .copy ()
62
+ state .ta_state = np .empty (self .number_of_classes * self .number_of_clauses * self .number_of_ta_chunks * self .number_of_state_bits ).astype (np .uint32 )
63
+ cuda .memcpy_dtoh (state .ta_state , self .ta_state_gpu )
64
+ state .clause_weights = np .empty (self .number_of_classes * self .number_of_clauses ).astype (np .uint8 )
65
+ cuda .memcpy_dtoh (sate .clause_weights , self .clause_weights_gpu )
66
+
67
+ print (state .keys ())
68
+ xxx
69
+ return state
70
+
71
+ def __setstate__ (self , state ):
72
+ self .__dict__ .update (state )
73
+ self .mc_ctm = _lib .CreateMultiClassTsetlinMachine (self .number_of_classes , self .number_of_clauses , self .number_of_features , self .number_of_patches , self .number_of_ta_chunks , self .number_of_state_bits , self .T , self .s , self .s_range , self .boost_true_positive_feedback , self .weighted_clauses , self .clause_drop_p , self .literal_drop_p )
74
+ self .set_state (state ['mc_ctm_state' ])
75
+
60
76
def encode_X (self , X , encoded_X_gpu ):
61
77
number_of_examples = X .shape [0 ]
62
78
@@ -89,11 +105,11 @@ def ta_action(self, mc_tm_class, clause, ta):
89
105
return (ta_state [mc_tm_class , clause , ta // 32 , self .number_of_state_bits - 1 ] & (1 << (ta % 32 ))) > 0
90
106
91
107
def get_state (self ):
92
- if np .array_equal (self .clause_weights , np . array ([])):
93
- self .ta_state = np . empty ( self .number_of_classes * self . number_of_clauses * self . number_of_ta_chunks * self . number_of_state_bits ). astype ( np . uint32 )
94
- cuda . memcpy_dtoh (self .ta_state , self .ta_state_gpu )
95
- self .clause_weights = np . empty ( self .number_of_classes * self . number_of_clauses ). astype ( np . uint8 )
96
- cuda . memcpy_dtoh ( self . clause_weights , self . clause_weights_gpu )
108
+ self . ta_state = np .empty (self .number_of_classes * self . number_of_clauses * self . number_of_ta_chunks * self . number_of_state_bits ). astype ( np . uint32 )
109
+ cuda . memcpy_dtoh ( self .ta_state , self .ta_state_gpu )
110
+ self . clause_weights = np . empty (self .number_of_classes * self .number_of_clauses ). astype ( np . uint8 )
111
+ cuda . memcpy_dtoh ( self .clause_weights , self .clause_weights_gpu )
112
+
97
113
return ((self .ta_state , self .clause_weights , self .number_of_classes , self .number_of_clauses , self .number_of_features , self .dim , self .patch_dim , self .number_of_patches , self .number_of_state_bits , self .max_weight , self .number_of_ta_chunks , self .append_negated , self .min_y , self .max_y ))
98
114
99
115
def set_state (self , state ):
0 commit comments