Skip to content

Commit 339436f

Browse files
authored
Merge pull request #1 from jhliu17/tape_semantics
MatrixStitcher New Release r0.2(rc): The New Tape Mechanism
2 parents a42f6ce + 729de76 commit 339436f

File tree

8 files changed

+232
-299
lines changed

8 files changed

+232
-299
lines changed

LeastSquare.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

example/LUFactorization2.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

homework2-1.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

homework2-2.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

matrixstitcher/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from matrixstitcher.backend import Matrix
2-
from matrixstitcher.backend import apply_pipeline as apply
2+
from matrixstitcher.backend import no_tape

matrixstitcher/backend.py

Lines changed: 89 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,15 @@ def __init__(self, data, dtype=np.float):
4949
derive from this matrix will share the elementary tape and hist and this behavior can be controled by causal
5050
parameter.
5151
'''
52-
self._elementary_tape = [[], []]
53-
self._elementary_hist = []
52+
self.__elementary_tape = [[], []]
53+
self.__tape = []
54+
self.__tape_hist = []
5455

5556
def refresh(self):
5657
self.matrix = np.array(self._origin_data, dtype=self._dtype)
57-
self._elementary_tape = [[], []]
58-
self._elementary_hist = []
58+
self.__elementary_tape = [[], []]
59+
self.__tape = []
60+
self.__tape_hist = []
5961

6062
def get_origin(self):
6163
return np.array(self._origin_data, dtype=self._dtype)
@@ -74,9 +76,8 @@ def __repr__(self):
7476
def __getitem__(self, key):
7577
key = index_mechanism(*key)
7678
data = self.matrix.__getitem__(key)
77-
if not isinstance(data, np.ndarray):
78-
data = np.array(data).reshape(1)
79-
79+
# if not isinstance(data, np.ndarray):
80+
# data = np.array(data).reshape(1)
8081
return Matrix(data, dtype=self.matrix.dtype)
8182

8283
def __setitem__(self, key, value):
@@ -99,7 +100,7 @@ def __mul__(self, other):
99100
result = self.matrix @ other.matrix
100101
return Matrix(result, dtype=result.dtype)
101102
elif isinstance(other, (int, float)):
102-
result = self.matrix * other
103+
result = self.matrix * other
103104
return Matrix(result, dtype=result.dtype)
104105
else:
105106
raise Exception('no defination')
@@ -134,50 +135,72 @@ def to_scalar(self):
134135

135136
@property
136137
def T(self):
137-
matrix = copy(self, causal=False)
138-
return F.transpose(matrix)
138+
from matrixstitcher.transform import Transpose
139+
# matrix = copy(self, causal=False)
140+
result = Transpose()(self)
141+
return result
139142

140-
def update_tape(self, transform_method, *args, **kwargs):
141-
from matrixstitcher.transform import __support_tape__
142-
143-
if transform_method in __support_tape__:
144-
determined = transform_method.split('_')[0].lower()
143+
def update_tape(self, transform, *args, **kwargs):
144+
transform_name = transform.__class__.__name__
145+
146+
if transform.is_elementary():
147+
determined = transform_name.lower()
145148
if 'row' in determined or 'column' in determined:
146149
direction = 'row' if 'row' in determined else 'column'
147150
size = self.shape[self._direction[direction]]
148151

149152
elementary = Matrix(np.eye(size), dtype=self._dtype)
150-
elementary = getattr(F, transform_method)(elementary, *args, **kwargs)
151-
self._elementary_tape[self._direction[direction]].append(elementary)
152-
method_name = ''.join([i[0].upper() + i[1:] for i in transform_method.split('_')])
153-
self._elementary_hist.append(transform_template(method_name, args, kwargs))
153+
with no_tape():
154+
elementary = transform(elementary)
155+
self.__elementary_tape[self._direction[direction]].append(elementary)
156+
157+
self.__tape.append(transform)
158+
self.__tape_hist.append(get_transform_template(transform_name, args, kwargs))
154159

155160
def get_elementary(self):
156-
return self._elementary_tape[0][::-1], self._elementary_tape[1]
161+
return self.__elementary_tape[0][::-1], self.__elementary_tape[1]
162+
163+
def get_transform_tape(self):
164+
return self.__tape, self.__tape_hist
165+
166+
def set_elementary(self, *args):
167+
args0 = args[0]
168+
args1 = args[1]
169+
args0 = args0[::-1]
170+
self.__elementary_tape = [args0, args1]
171+
172+
def set_transform_tape(self, *args):
173+
self.__tape = args[0]
174+
self.__tape_hist = args[1]
157175

158176
def forward(self, causal=True, display=False):
159-
left_tape, right_tape = self.get_elementary()
160-
161-
if not display:
162-
foward_tape = left_tape + [self] + right_tape
163-
if len(foward_tape) > 1:
164-
result = reduce(lambda x, y: x * y, foward_tape)
165-
else:
166-
result = foward_tape[0]
167-
else:
168-
i, j = 0, 0
169-
result = self
170-
print('-> Origin matrix:\n{}\n'.format(result))
171-
for idx, method in enumerate(self._elementary_hist, 1):
172-
if 'row' in method.lower():
173-
result = self._elementary_tape[self._direction['row']][i] * result
174-
i += 1
175-
elif 'column' in method.lower():
176-
result = result * self._elementary_tape[self._direction['column']][j]
177-
j += 1
178-
else:
179-
raise Exception('An illegal method in the elementary tape history')
180-
print('-> Stage {}, {}:\n{}\n'.format(idx, method, result))
177+
# Have been deprecated for the changing of tape semantics
178+
179+
# left_tape, right_tape = self.get_elementary()
180+
# if not display:
181+
# foward_tape = left_tape + [self] + right_tape
182+
# if len(foward_tape) > 1:
183+
# result = reduce(lambda x, y: x * y, foward_tape)
184+
# else:
185+
# result = foward_tape[0]
186+
# else:
187+
# i, j = 0, 0
188+
# result = self
189+
# print('-> Origin matrix:\n{}\n'.format(result))
190+
# for idx, method in enumerate(self._elementary_hist, 1):
191+
# if 'row' in method.lower():
192+
# result = self._elementary_tape[self._direction['row']][i] * result
193+
# i += 1
194+
# elif 'column' in method.lower():
195+
# result = result * self._elementary_tape[self._direction['column']][j]
196+
# j += 1
197+
# else:
198+
# raise Exception('An illegal method in the elementary tape history')
199+
# print('-> Stage {}, {}:\n{}\n'.format(idx, method, result))
200+
201+
pipeline, _ = self.get_transform_tape()
202+
with no_tape():
203+
result = self.apply(pipeline, display=display, forward=True)
181204

182205
# Manual operation
183206
if causal:
@@ -195,24 +218,6 @@ def numpy(self):
195218
return self.matrix.reshape(-1)
196219
else:
197220
return self.matrix
198-
# have been deprecated
199-
# def row_transform(*args):
200-
# return F.row_transform(*args)
201-
202-
# def column_transform(*args):
203-
# return F.column_transform(*args)
204-
205-
# def row_swap(*args):
206-
# return F.row_swap(*args)
207-
208-
# def column_swap(*args):
209-
# return F.column_swap(*args)
210-
211-
# def row_mul(*args):
212-
# return F.row_mul(*args)
213-
214-
# def column_mul(*args):
215-
# return F.column_mul(*args)
216221

217222

218223
def index_mechanism(*key):
@@ -227,30 +232,34 @@ def slice_mechanism(key: slice):
227232
return slice(start, stop, step)
228233

229234

230-
def transform_template(p, _args, _kwargs):
235+
def get_transform_template(p, *_args, **_kwargs):
231236
template = '{}{}'.format(p, _args + tuple('{}={}'.format(i, _kwargs[i]) for i in _kwargs))
232237
return template
233238

234239

235-
def apply_pipeline(matrix: Matrix, pipeline, display=False):
240+
def apply_pipeline(matrix: Matrix, pipeline, display=False, forward=False):
236241
'''
237242
A list or tuple of tranforms to apply on the input matrix.
238243
'''
239244
from matrixstitcher.transform import Transform
245+
240246
assert isinstance(pipeline, (list, tuple, Transform))
241-
done_pipeline = len(matrix._elementary_tape[0] + matrix._elementary_tape[1])
247+
if forward:
248+
done_pipeline = 0
249+
else:
250+
done_pipeline = len(matrix.get_transform_tape()[0])
242251

243252
if isinstance(pipeline, Transform):
244253
pipeline = [pipeline]
245254

246255
if display:
247256
if done_pipeline == 0:
248257
print('-> Origin matrix:\n{}\n'.format(matrix))
249-
for idx, p in enumerate(pipeline, done_pipeline+1):
258+
for idx, p in enumerate(pipeline, done_pipeline + 1):
250259
assert isinstance(p, Transform)
251260
matrix = p(matrix)
252-
transform_template_ = transform_template(p.__class__.__name__, p._args, p._kwargs)
253-
print('-> Stage {}, {}:\n{}\n'.format(idx, transform_template_, matrix))
261+
transform_template = repr(p)
262+
print('-> Stage {}, {}:\n{}\n'.format(idx, transform_template, matrix))
254263
else:
255264
for p in pipeline:
256265
assert isinstance(p, Transform)
@@ -263,6 +272,17 @@ def copy(matrix: Matrix, causal=True):
263272

264273
# Manual operation
265274
if causal:
266-
new_matrix._elementary_tape = matrix._elementary_tape
267-
new_matrix._elementary_hist = matrix._elementary_hist
268-
return new_matrix
275+
new_matrix.set_elementary(*matrix.get_elementary())
276+
new_matrix.set_transform_tape(*matrix.get_transform_tape())
277+
return new_matrix
278+
279+
280+
class no_tape:
281+
def __enter__(self):
282+
from matrixstitcher.transform import Transform
283+
self.prev = Transform.is_tape_enabled()
284+
Transform.set_tape_enabled(False)
285+
286+
def __exit__(self, exc_type, exc_val, exc_tb):
287+
from matrixstitcher.transform import Transform
288+
Transform.set_tape_enabled(self.prev)

0 commit comments

Comments
 (0)