@@ -49,13 +49,15 @@ def __init__(self, data, dtype=np.float):
49
49
derive from this matrix will share the elementary tape and hist and this behavior can be controled by causal
50
50
parameter.
51
51
'''
52
- self ._elementary_tape = [[], []]
53
- self ._elementary_hist = []
52
+ self .__elementary_tape = [[], []]
53
+ self .__tape = []
54
+ self .__tape_hist = []
54
55
55
56
def refresh (self ):
56
57
self .matrix = np .array (self ._origin_data , dtype = self ._dtype )
57
- self ._elementary_tape = [[], []]
58
- self ._elementary_hist = []
58
+ self .__elementary_tape = [[], []]
59
+ self .__tape = []
60
+ self .__tape_hist = []
59
61
60
62
def get_origin (self ):
61
63
return np .array (self ._origin_data , dtype = self ._dtype )
@@ -74,9 +76,8 @@ def __repr__(self):
74
76
def __getitem__ (self , key ):
75
77
key = index_mechanism (* key )
76
78
data = self .matrix .__getitem__ (key )
77
- if not isinstance (data , np .ndarray ):
78
- data = np .array (data ).reshape (1 )
79
-
79
+ # if not isinstance(data, np.ndarray):
80
+ # data = np.array(data).reshape(1)
80
81
return Matrix (data , dtype = self .matrix .dtype )
81
82
82
83
def __setitem__ (self , key , value ):
@@ -99,7 +100,7 @@ def __mul__(self, other):
99
100
result = self .matrix @ other .matrix
100
101
return Matrix (result , dtype = result .dtype )
101
102
elif isinstance (other , (int , float )):
102
- result = self .matrix * other
103
+ result = self .matrix * other
103
104
return Matrix (result , dtype = result .dtype )
104
105
else :
105
106
raise Exception ('no defination' )
@@ -134,50 +135,72 @@ def to_scalar(self):
134
135
135
136
@property
136
137
def T (self ):
137
- matrix = copy (self , causal = False )
138
- return F .transpose (matrix )
138
+ from matrixstitcher .transform import Transpose
139
+ # matrix = copy(self, causal=False)
140
+ result = Transpose ()(self )
141
+ return result
139
142
140
- def update_tape (self , transform_method , * args , ** kwargs ):
141
- from matrixstitcher . transform import __support_tape__
142
-
143
- if transform_method in __support_tape__ :
144
- determined = transform_method . split ( '_' )[ 0 ] .lower ()
143
+ def update_tape (self , transform , * args , ** kwargs ):
144
+ transform_name = transform . __class__ . __name__
145
+
146
+ if transform . is_elementary () :
147
+ determined = transform_name .lower ()
145
148
if 'row' in determined or 'column' in determined :
146
149
direction = 'row' if 'row' in determined else 'column'
147
150
size = self .shape [self ._direction [direction ]]
148
151
149
152
elementary = Matrix (np .eye (size ), dtype = self ._dtype )
150
- elementary = getattr (F , transform_method )(elementary , * args , ** kwargs )
151
- self ._elementary_tape [self ._direction [direction ]].append (elementary )
152
- method_name = '' .join ([i [0 ].upper () + i [1 :] for i in transform_method .split ('_' )])
153
- self ._elementary_hist .append (transform_template (method_name , args , kwargs ))
153
+ with no_tape ():
154
+ elementary = transform (elementary )
155
+ self .__elementary_tape [self ._direction [direction ]].append (elementary )
156
+
157
+ self .__tape .append (transform )
158
+ self .__tape_hist .append (get_transform_template (transform_name , args , kwargs ))
154
159
155
160
def get_elementary (self ):
156
- return self ._elementary_tape [0 ][::- 1 ], self ._elementary_tape [1 ]
161
+ return self .__elementary_tape [0 ][::- 1 ], self .__elementary_tape [1 ]
162
+
163
+ def get_transform_tape (self ):
164
+ return self .__tape , self .__tape_hist
165
+
166
+ def set_elementary (self , * args ):
167
+ args0 = args [0 ]
168
+ args1 = args [1 ]
169
+ args0 = args0 [::- 1 ]
170
+ self .__elementary_tape = [args0 , args1 ]
171
+
172
+ def set_transform_tape (self , * args ):
173
+ self .__tape = args [0 ]
174
+ self .__tape_hist = args [1 ]
157
175
158
176
def forward (self , causal = True , display = False ):
159
- left_tape , right_tape = self .get_elementary ()
160
-
161
- if not display :
162
- foward_tape = left_tape + [self ] + right_tape
163
- if len (foward_tape ) > 1 :
164
- result = reduce (lambda x , y : x * y , foward_tape )
165
- else :
166
- result = foward_tape [0 ]
167
- else :
168
- i , j = 0 , 0
169
- result = self
170
- print ('-> Origin matrix:\n {}\n ' .format (result ))
171
- for idx , method in enumerate (self ._elementary_hist , 1 ):
172
- if 'row' in method .lower ():
173
- result = self ._elementary_tape [self ._direction ['row' ]][i ] * result
174
- i += 1
175
- elif 'column' in method .lower ():
176
- result = result * self ._elementary_tape [self ._direction ['column' ]][j ]
177
- j += 1
178
- else :
179
- raise Exception ('An illegal method in the elementary tape history' )
180
- print ('-> Stage {}, {}:\n {}\n ' .format (idx , method , result ))
177
+ # Have been deprecated for the changing of tape semantics
178
+
179
+ # left_tape, right_tape = self.get_elementary()
180
+ # if not display:
181
+ # foward_tape = left_tape + [self] + right_tape
182
+ # if len(foward_tape) > 1:
183
+ # result = reduce(lambda x, y: x * y, foward_tape)
184
+ # else:
185
+ # result = foward_tape[0]
186
+ # else:
187
+ # i, j = 0, 0
188
+ # result = self
189
+ # print('-> Origin matrix:\n{}\n'.format(result))
190
+ # for idx, method in enumerate(self._elementary_hist, 1):
191
+ # if 'row' in method.lower():
192
+ # result = self._elementary_tape[self._direction['row']][i] * result
193
+ # i += 1
194
+ # elif 'column' in method.lower():
195
+ # result = result * self._elementary_tape[self._direction['column']][j]
196
+ # j += 1
197
+ # else:
198
+ # raise Exception('An illegal method in the elementary tape history')
199
+ # print('-> Stage {}, {}:\n{}\n'.format(idx, method, result))
200
+
201
+ pipeline , _ = self .get_transform_tape ()
202
+ with no_tape ():
203
+ result = self .apply (pipeline , display = display , forward = True )
181
204
182
205
# Manual operation
183
206
if causal :
@@ -195,24 +218,6 @@ def numpy(self):
195
218
return self .matrix .reshape (- 1 )
196
219
else :
197
220
return self .matrix
198
- # have been deprecated
199
- # def row_transform(*args):
200
- # return F.row_transform(*args)
201
-
202
- # def column_transform(*args):
203
- # return F.column_transform(*args)
204
-
205
- # def row_swap(*args):
206
- # return F.row_swap(*args)
207
-
208
- # def column_swap(*args):
209
- # return F.column_swap(*args)
210
-
211
- # def row_mul(*args):
212
- # return F.row_mul(*args)
213
-
214
- # def column_mul(*args):
215
- # return F.column_mul(*args)
216
221
217
222
218
223
def index_mechanism (* key ):
@@ -227,30 +232,34 @@ def slice_mechanism(key: slice):
227
232
return slice (start , stop , step )
228
233
229
234
230
- def transform_template (p , _args , _kwargs ):
235
+ def get_transform_template (p , * _args , ** _kwargs ):
231
236
template = '{}{}' .format (p , _args + tuple ('{}={}' .format (i , _kwargs [i ]) for i in _kwargs ))
232
237
return template
233
238
234
239
235
- def apply_pipeline (matrix : Matrix , pipeline , display = False ):
240
+ def apply_pipeline (matrix : Matrix , pipeline , display = False , forward = False ):
236
241
'''
237
242
A list or tuple of tranforms to apply on the input matrix.
238
243
'''
239
244
from matrixstitcher .transform import Transform
245
+
240
246
assert isinstance (pipeline , (list , tuple , Transform ))
241
- done_pipeline = len (matrix ._elementary_tape [0 ] + matrix ._elementary_tape [1 ])
247
+ if forward :
248
+ done_pipeline = 0
249
+ else :
250
+ done_pipeline = len (matrix .get_transform_tape ()[0 ])
242
251
243
252
if isinstance (pipeline , Transform ):
244
253
pipeline = [pipeline ]
245
254
246
255
if display :
247
256
if done_pipeline == 0 :
248
257
print ('-> Origin matrix:\n {}\n ' .format (matrix ))
249
- for idx , p in enumerate (pipeline , done_pipeline + 1 ):
258
+ for idx , p in enumerate (pipeline , done_pipeline + 1 ):
250
259
assert isinstance (p , Transform )
251
260
matrix = p (matrix )
252
- transform_template_ = transform_template ( p . __class__ . __name__ , p . _args , p . _kwargs )
253
- print ('-> Stage {}, {}:\n {}\n ' .format (idx , transform_template_ , matrix ))
261
+ transform_template = repr ( p )
262
+ print ('-> Stage {}, {}:\n {}\n ' .format (idx , transform_template , matrix ))
254
263
else :
255
264
for p in pipeline :
256
265
assert isinstance (p , Transform )
@@ -263,6 +272,17 @@ def copy(matrix: Matrix, causal=True):
263
272
264
273
# Manual operation
265
274
if causal :
266
- new_matrix ._elementary_tape = matrix ._elementary_tape
267
- new_matrix ._elementary_hist = matrix ._elementary_hist
268
- return new_matrix
275
+ new_matrix .set_elementary (* matrix .get_elementary ())
276
+ new_matrix .set_transform_tape (* matrix .get_transform_tape ())
277
+ return new_matrix
278
+
279
+
280
+ class no_tape :
281
+ def __enter__ (self ):
282
+ from matrixstitcher .transform import Transform
283
+ self .prev = Transform .is_tape_enabled ()
284
+ Transform .set_tape_enabled (False )
285
+
286
+ def __exit__ (self , exc_type , exc_val , exc_tb ):
287
+ from matrixstitcher .transform import Transform
288
+ Transform .set_tape_enabled (self .prev )
0 commit comments