Skip to content

Commit 61150cc

Browse files
committed
Update
1 parent e22269c commit 61150cc

File tree

1 file changed

+11
-1
lines changed
  • PyTsetlinMachineCUDA

1 file changed

+11
-1
lines changed

PyTsetlinMachineCUDA/tm.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,18 @@ def __getstate__(self):
9494

9595
def __setstate__(self, state):
9696
self.__dict__.update(state)
97-
self.set_state((self.ta_state, self.clause_weights, self.number_of_classes, self.number_of_clauses, self.number_of_features, self.dim, self.patch_dim, self.number_of_patches, self.number_of_state_bits, self.max_weight, self.number_of_ta_chunks, self.append_negated, self.min_y, self.max_y))
97+
self.X_train = np.array([])
98+
self.Y_train = np.array([])
99+
self.X_test = np.array([])
100+
101+
mod_encode = SourceModule(kernels.code_encode, no_extern_c=True)
102+
self.prepare_encode = mod_encode.get_function("prepare_encode")
103+
self.encode = mod_encode.get_function("encode")
98104

105+
self.ta_state_gpu = cuda.mem_alloc(self.number_of_classes*self.number_of_clauses*self.number_of_ta_chunks*self.number_of_state_bits*4)
106+
self.clause_weights_gpu = cuda.mem_alloc(self.number_of_classes*self.number_of_clauses)
107+
cuda.memcpy_htod(self.ta_state_gpu, self.ta_state)
108+
cuda.memcpy_htod(self.clause_weights_gpu, self.clause_weights)
99109

100110
def encode_X(self, X, encoded_X_gpu):
101111
number_of_examples = X.shape[0]

0 commit comments

Comments
 (0)