Skip to content

Commit c0da3f9

Browse files
committed
Started adding pickle support
1 parent 2e2666f commit c0da3f9

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

PyTsetlinMachineCUDA/tm.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,22 @@ def __init__(self, number_of_clauses, T, s, boost_true_positive_feedback=1, numb
5757
self.prepare_encode = mod_encode.get_function("prepare_encode")
5858
self.encode = mod_encode.get_function("encode")
5959

60+
def __getstate__(self):
61+
state = self.__dict__.copy()
62+
state.ta_state = np.empty(self.number_of_classes*self.number_of_clauses*self.number_of_ta_chunks*self.number_of_state_bits).astype(np.uint32)
63+
cuda.memcpy_dtoh(state.ta_state, self.ta_state_gpu)
64+
state.clause_weights = np.empty(self.number_of_classes*self.number_of_clauses).astype(np.uint8)
65+
cuda.memcpy_dtoh(sate.clause_weights, self.clause_weights_gpu)
66+
67+
print(state.keys())
68+
xxx
69+
return state
70+
71+
def __setstate__(self, state):
72+
self.__dict__.update(state)
73+
self.mc_ctm = _lib.CreateMultiClassTsetlinMachine(self.number_of_classes, self.number_of_clauses, self.number_of_features, self.number_of_patches, self.number_of_ta_chunks, self.number_of_state_bits, self.T, self.s, self.s_range, self.boost_true_positive_feedback, self.weighted_clauses, self.clause_drop_p, self.literal_drop_p)
74+
self.set_state(state['mc_ctm_state'])
75+
6076
def encode_X(self, X, encoded_X_gpu):
6177
number_of_examples = X.shape[0]
6278

@@ -89,11 +105,11 @@ def ta_action(self, mc_tm_class, clause, ta):
89105
return (ta_state[mc_tm_class, clause, ta // 32, self.number_of_state_bits-1] & (1 << (ta % 32))) > 0
90106

91107
def get_state(self):
92-
if np.array_equal(self.clause_weights, np.array([])):
93-
self.ta_state = np.empty(self.number_of_classes*self.number_of_clauses*self.number_of_ta_chunks*self.number_of_state_bits).astype(np.uint32)
94-
cuda.memcpy_dtoh(self.ta_state, self.ta_state_gpu)
95-
self.clause_weights = np.empty(self.number_of_classes*self.number_of_clauses).astype(np.uint8)
96-
cuda.memcpy_dtoh(self.clause_weights, self.clause_weights_gpu)
108+
self.ta_state = np.empty(self.number_of_classes*self.number_of_clauses*self.number_of_ta_chunks*self.number_of_state_bits).astype(np.uint32)
109+
cuda.memcpy_dtoh(self.ta_state, self.ta_state_gpu)
110+
self.clause_weights = np.empty(self.number_of_classes*self.number_of_clauses).astype(np.uint8)
111+
cuda.memcpy_dtoh(self.clause_weights, self.clause_weights_gpu)
112+
97113
return((self.ta_state, self.clause_weights, self.number_of_classes, self.number_of_clauses, self.number_of_features, self.dim, self.patch_dim, self.number_of_patches, self.number_of_state_bits, self.max_weight, self.number_of_ta_chunks, self.append_negated, self.min_y, self.max_y))
98114

99115
def set_state(self, state):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name='PyTsetlinMachineCUDA',
5-
version='0.1.8',
5+
version='0.1.9',
66
author='Ole-Christoffer Granmo',
77
author_email='ole.granmo@uia.no',
88
url='https://github.com/cair/pyTsetlinMachineCUDA/',

0 commit comments

Comments
 (0)