diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a7361f7
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+*~
+__pycache__
+.DS_Store
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..9139047
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1 @@
+import adlib as adlib
diff --git a/adlib/__init__.py b/adlib/__init__.py
new file mode 100644
index 0000000..985198f
--- /dev/null
+++ b/adlib/__init__.py
@@ -0,0 +1,3 @@
+from .svd import SVD
+from .eigh import EigenSolver
+from .qr import QR
diff --git a/adlib/eigh.py b/adlib/eigh.py
new file mode 100644
index 0000000..795bbb3
--- /dev/null
+++ b/adlib/eigh.py
@@ -0,0 +1,40 @@
+import numpy as np
+import torch
+
+class EigenSolver(torch.autograd.Function):
+    @staticmethod
+    def forward(self, A):
+        w, v = torch.symeig(A, eigenvectors=True)
+
+        self.save_for_backward(w, v)
+        return w, v
+
+    @staticmethod
+    def backward(self, dw, dv):
+        '''
+        Backward of the symmetric eigendecomposition, see
+        https://j-towns.github.io/papers/svd-derivative.pdf
+        '''
+        w, v = self.saved_tensors
+        dtype, device = w.dtype, w.device
+        N = v.shape[0]
+
+        # F[i, j] = 1/(w[j] - w[i]) for i != j; the diagonal becomes 0 after the inf fill
+        F = w - w[:,None]
+        F.diagonal().fill_(np.inf)
+        F = 1./F
+
+        vt = v.t()
+        vdv = vt@dv
+
+        return v @ (torch.diag(dw) + F*(vdv - vdv.t())/2) @ vt
+
+def test_eigs():
+    M = 2
+    torch.manual_seed(42)
+    A = torch.rand(M, M, dtype=torch.float64)
+    A = torch.nn.Parameter(A+A.t())
+    assert(torch.autograd.gradcheck(EigenSolver.apply, A, eps=1e-6, atol=1e-4))
+    print("Test Pass!")
+
+if __name__=='__main__':
+    test_eigs()
diff --git a/adlib/power.py b/adlib/power.py
new file mode 100644
index 0000000..5494df4
--- /dev/null
+++ b/adlib/power.py
@@ -0,0 +1,70 @@
+import torch
+from torch.utils.checkpoint import detach_variable
+
+def step(A, x):
+    # one power-iteration step, with the sign of the first entry fixed
+    y = A@x
+    y = y[0].sign() * y
+    return y/y.norm()
+
+class FixedPoint(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, A, x0, tol):
+        x, x_prev = step(A, x0), x0
+        while torch.dist(x, x_prev) > tol:
+            x, x_prev = step(A, x), x
+        ctx.save_for_backward(A, x)
+        ctx.tol = tol
+        return x
+
+    @staticmethod
+    def backward(ctx, grad):
+        A, x = detach_variable(ctx.saved_tensors)
+        dA = grad
+        # accumulate the geometric series of vector-Jacobian products at the fixed point
+        while True:
+            with torch.enable_grad():
+                grad = torch.autograd.grad(step(A, x), x, grad_outputs=grad)[0]
+            if (torch.norm(grad) > ctx.tol):
+                dA = dA + grad
+            else:
+                break
+        with torch.enable_grad():
+            dA = torch.autograd.grad(step(A, x), A, grad_outputs=dA)[0]
+        return dA, None, None
+
+def test_backward():
+    N = 4
+    torch.manual_seed(2)
+    A = torch.rand(N, N, dtype=torch.float64, requires_grad=True)
+    x0 = torch.rand(N, dtype=torch.float64)
+    x0 = x0/x0.norm()
+    tol = 1E-10
+
+    input = A, x0, tol
+    assert(torch.autograd.gradcheck(FixedPoint.apply, input, eps=1E-6, atol=tol))
+
+    print("Backward Test Pass!")
+
+def test_forward():
+    torch.manual_seed(42)
+    N = 100
+    tol = 1E-8
+    dtype = torch.float64
+    A = torch.randn(N, N, dtype=dtype)
+    A = A+A.t()
+
+    w, v = torch.symeig(A, eigenvectors=True)
+    idx = torch.argmax(w.abs())
+
+    v_exact = v[:, idx]
+    v_exact = v_exact[0].sign() * v_exact
+
+    x0 = torch.rand(N, dtype=dtype)
+    x0 = x0/x0.norm()
+    x = FixedPoint.apply(A, x0, tol)
+
+    assert(torch.allclose(v_exact, x, rtol=tol, atol=tol))
+    print("Forward Test Pass!")
+
+if __name__=='__main__':
+    test_forward()
+    test_backward()
diff --git a/adlib/qr.py b/adlib/qr.py
new file mode 100644
index 0000000..40f2320
--- /dev/null
+++ b/adlib/qr.py
@@ -0,0 +1,53 @@
+import torch
+
+class QR(torch.autograd.Function):
+    @staticmethod
+    def forward(self, A):
+        Q, R = torch.qr(A)
+        self.save_for_backward(A, Q, R)
+        return Q, R
+
+    @staticmethod
+    def backward(self, dq, dr):
+        A, q, r = self.saved_tensors
+        if r.shape[0] == r.shape[1]:
+            return _simple_qr_backward(q, r, dq, dr)
+        M, N = r.shape
+        B = A[:,M:]
+        dU = dr[:,:M]
+        dD = dr[:,M:]
+        U = r[:,:M]
+        da = _simple_qr_backward(q, U, dq+B@dD.t(), dU)
+        db = q@dD
+        return torch.cat([da, db], 1)
+
+def _simple_qr_backward(q, r, dq, dr):
+    if r.shape[-2] != r.shape[-1]:
+        raise NotImplementedError("QrGrad not implemented when ncols > nrows "
+                                  "or full_matrices is true and ncols != nrows.")
+
+    qdq = q.t() @ dq
+    qdq_ = qdq - qdq.t()
+    rdr = r @ dr.t()
+    rdr_ = rdr - rdr.t()
+    tril = torch.tril(qdq_ + rdr_)
+
+    def _TriangularSolve(x, r):
+        """Equiv to x @ torch.inverse(r).t() if r is upper-tri."""
+        res = torch.trtrs(x.t(), r, upper=True, transpose=False)[0].t()
+        return res
+
+    grad_a = q @ (dr + _TriangularSolve(tril, r))
+    grad_b = _TriangularSolve(dq - q @ qdq, r)
+    return grad_a + grad_b
+
+def test_qr():
+    M, N = 4, 6
+    torch.manual_seed(2)
+    A = torch.randn(M, N)
+    A.requires_grad=True
+    assert(torch.autograd.gradcheck(QR.apply, A, eps=1e-4, atol=1e-2))
+    print("Test Pass!")
+
+if __name__ == "__main__":
+    test_qr()
diff --git a/adlib/svd.py b/adlib/svd.py
new file mode 100644
index 0000000..c07061a
--- /dev/null
+++ b/adlib/svd.py
@@ -0,0 +1,61 @@
+import numpy as np
+import scipy.linalg
+import torch
+
+def safe_inverse(x, epsilon=1E-12):
+    return x/(x**2 + epsilon)
+
+class SVD(torch.autograd.Function):
+    @staticmethod
+    def forward(self, A):
+        U, S, V = torch.svd(A)
+        #numpy_input = A.detach().numpy()
+        #U, S, Vt = scipy.linalg.svd(numpy_input, full_matrices=False, lapack_driver='gesvd')
+        #U = torch.as_tensor(U, dtype=A.dtype, device=A.device)
+        #S = torch.as_tensor(S, dtype=A.dtype, device=A.device)
+        #V = torch.as_tensor(np.transpose(Vt), dtype=A.dtype, device=A.device)
+        self.save_for_backward(U, S, V)
+        return U, S, V
+
+    @staticmethod
+    def backward(self, dU, dS, dV):
+        U, S, V = self.saved_tensors
+        Vt = V.t()
+        Ut = U.t()
+        M = U.size(0)
+        N = V.size(0)
+        NS = len(S)
+
+        F = (S - S[:, None])
+        F = safe_inverse(F)
+        F.diagonal().fill_(0)
+
+        G = (S + S[:, None])
+        G.diagonal().fill_(np.inf)
+        G = 1/G
+
+        UdU = Ut @ dU
+        VdV = Vt @ dV
+
+        Su = (F+G)*(UdU-UdU.t())/2
+        Sv = (F-G)*(VdV-VdV.t())/2
+
+        dA = U @ (Su + Sv + torch.diag(dS)) @ Vt
+        if (M>NS):
+            dA = dA + (torch.eye(M, dtype=dU.dtype, device=dU.device) - U@Ut) @ (dU/S) @ Vt
+        if (N>NS):
+            dA = dA + (U/S) @ dV.t() @ (torch.eye(N, dtype=dU.dtype, device=dU.device) - V@Vt)
+        #print (dU.norm().item(), dS.norm().item(), dV.norm().item())
+        #print (Su.norm().item(), Sv.norm().item(), dS.norm().item())
+        return dA
+
+def test_svd():
+    M, N = 50, 40
+    torch.manual_seed(2)
+    input = torch.rand(M, N, dtype=torch.float64, requires_grad=True)
+    assert(torch.autograd.gradcheck(SVD.apply, input, eps=1e-6, atol=1e-4))
+    print("Test Pass!")
+
+if __name__=='__main__':
+    test_svd()
diff --git a/args.py b/args.py
new file mode 100644
index 0000000..db40fea
--- /dev/null
+++ b/args.py
@@ -0,0 +1,21 @@
+import argparse
+
+parser = argparse.ArgumentParser(description='')
+parser.add_argument("-folder", default='../data/', help="where to store results")
+parser.add_argument("-d", type=int, default=2, help="d")
+parser.add_argument("-D", type=int, default=2, help="D")
+parser.add_argument("-chi", type=int, default=20, help="chi")
+parser.add_argument("-Nepochs", type=int, default=50, help="Nepochs")
+parser.add_argument("-Maxiter", type=int, default=50, help="Maxiter") +parser.add_argument("-Jz", type=float, default=1.0, help="Jz") +parser.add_argument("-Jxy", type=float, default=1.0, help="Jxy") +parser.add_argument("-hx", type=float, default=1.0, help="hx") +parser.add_argument("-model", default='Heisenberg', choices=['TFIM', 'Heisenberg'], help="model name") +parser.add_argument("-load", default=None, help="load") +parser.add_argument("-save_period", type=int, default=1, help="") +parser.add_argument("-float32", action='store_true', help="use float32") +parser.add_argument("-use_checkpoint", action='store_true', help="use checkpoint") +parser.add_argument("-cuda", type=int, default=-1, help="use GPU") + +args = parser.parse_args() + diff --git a/ctmrg.py b/ctmrg.py new file mode 100644 index 0000000..ef9548c --- /dev/null +++ b/ctmrg.py @@ -0,0 +1,51 @@ +import torch +from renormalize import renormalize +from torch.utils import checkpoint + +def ctmrg(T, chi, max_iter, use_checkpoint=False): + + threshold = 1E-12 if T.dtype is torch.float64 else 1E-6 # ctmrg convergence threshold + + C = T.sum((0,1)) + E = T.sum(1) + + truncation_error = 0.0 + diff = 1E10 + sold = torch.zeros(chi, dtype=T.dtype, device=T.device) + for n in range(max_iter): + tensors = T, C, E + if use_checkpoint: # use checkpoint to save memory + C, E, s, error = checkpoint(*tensors) + else: + C, E, s, error = renormalize(*tensors) + + truncation_error = max(truncation_error, error.item()) + if (s.numel() == sold.numel()): + diff = (s-sold).norm().item() + #print( 'n: %d, error: %e, diff: %e' % (n, error.item(), diff) ) + + if (diff < threshold): + break + sold = s + print ('ctmrg iterations, diff, error', n, diff, truncation_error/n) + + return C, E + +if __name__=='__main__': + import time + D = 64 + chi = 150 + device = 'cuda:0' + T = torch.randn(D, D, D, D, dtype=torch.float64, device=device, requires_grad=True) + T = (T + T.permute(3, 1, 2, 0))/2. + T = (T + T.permute(0, 2, 1, 3))/2. + T = (T + T.permute(2, 3, 0, 1))/2. + T = (T + T.permute(1, 0, 3, 2))/2. + T = T/T.norm() + + C = torch.randn(chi, chi, dtype=torch.float64, device=device, requires_grad=True) + C = (C+C.t())/2. + E = torch.randn(chi, D, chi, dtype=torch.float64, device=device, requires_grad=True) + E = (E + E.permute(2, 1, 0))/2. 
+
+    tensors = T, C, E
+    checkpoint(renormalize, *tensors)
diff --git a/ipeps.py b/ipeps.py
new file mode 100644
index 0000000..c871031
--- /dev/null
+++ b/ipeps.py
@@ -0,0 +1,26 @@
+import torch
+from ctmrg import ctmrg
+from measure import get_obs
+from utils import symmetrize
+from args import args
+
+class iPEPS(torch.nn.Module):
+    def __init__(self, args, dtype=torch.float64, device='cpu', use_checkpoint=False):
+        super(iPEPS, self).__init__()
+
+        self.d, self.D = args.d, args.D # physical and bond dimensions, needed by utils.load_checkpoint
+        B = torch.rand(args.d, args.D, args.D, args.D, args.D, dtype=dtype, device=device)
+        B = B/B.norm()
+        self.A = torch.nn.Parameter(B)
+
+    def forward(self, H, Mpx, Mpy, Mpz):
+
+        Asymm = symmetrize(self.A)
+
+        d, D = Asymm.shape[0], Asymm.shape[1]
+        T = (Asymm.view(d, -1).t()@Asymm.view(d, -1)).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
+        T = T/T.norm()
+
+        C, E = ctmrg(T, args.chi, args.Maxiter, args.use_checkpoint)
+        loss, Mx, My, Mz = get_obs(Asymm, H, Mpx, Mpy, Mpz, C, E)
+
+        return loss, Mx, My, Mz
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..6be744b
--- /dev/null
+++ b/main.py
@@ -0,0 +1,95 @@
+'''
+Variational iPEPS with automatic differentiation and GPU support
+'''
+import io
+import sys
+import numpy as np
+import torch
+torch.set_num_threads(1)
+torch.manual_seed(42)
+import subprocess
+from utils import kronecker_product as kron
+from utils import save_checkpoint, load_checkpoint
+from ipeps import iPEPS
+from args import args
+
+if __name__=='__main__':
+    import time
+    device = torch.device("cpu" if args.cuda<0 else "cuda:"+str(args.cuda))
+    dtype = torch.float32 if args.float32 else torch.float64
+    print ('use', dtype)
+
+    model = iPEPS(args, dtype, device, args.use_checkpoint)
+    optimizer = torch.optim.LBFGS(model.parameters(), max_iter=10)
+
+    if args.load is not None:
+        try:
+            load_checkpoint(args.load, model, optimizer)
+            print('load model', args.load)
+        except FileNotFoundError:
+            print('not found:', args.load)
+    params = list(model.parameters())
+    params = list(filter(lambda p: p.requires_grad, params))
+    nparams = sum([np.prod(p.size()) for p in params])
+    print ('total number of trainable parameters:', nparams)
+
+    key = args.folder
+    key += args.model \
+        + '_D' + str(args.D) \
+        + '_chi' + str(args.chi)
+    if (args.float32):
+        key += '_float32'
+    cmd = ['mkdir', '-p', key]
+    subprocess.check_call(cmd)
+
+    if args.model == 'TFIM':
+
+        sx = torch.tensor([[0, 1], [1, 0]], dtype=dtype, device=device)
+        sy = torch.tensor([[0, -1], [1, 0]], dtype=dtype, device=device)
+        sz = torch.tensor([[1, 0], [0, -1]], dtype=dtype, device=device)
+        id2 = torch.tensor([[1, 0], [0, 1]], dtype=dtype, device=device)
+
+        H = - 2*args.Jz*kron(sz,sz)-args.hx*(kron(sx,id2)+kron(id2,sx))/2
+        Mpx = (kron(id2,sx) + kron(sx,id2))/2
+        Mpy = (kron(id2,sy) + kron(sy,id2))/2
+        Mpz = (kron(id2,sz) + kron(sz,id2))/2
+
+    elif args.model == 'Heisenberg':
+        #Hamiltonian operators on a bond
+        sx = torch.tensor([[0, 1], [1, 0]], dtype=dtype, device=device)*0.5
+        sy = torch.tensor([[0, -1], [1, 0]], dtype=dtype, device=device)*0.5
+        sp = torch.tensor([[0, 1], [0, 0]], dtype=dtype, device=device)
+        sm = torch.tensor([[0, 0], [1, 0]], dtype=dtype, device=device)
+        sz = torch.tensor([[1, 0], [0, -1]], dtype=dtype, device=device)*0.5
+        id2 = torch.tensor([[1, 0], [0, 1]], dtype=dtype, device=device)
+
+        # now assuming Jz>0, Jxy>0; operators on the second site are rotated by sigma_x (=2*sx)
+        H = 2*args.Jz*kron(sz,4*sx@sz@sx)-args.Jxy*(kron(sm, 4*sx@sp@sx)+kron(sp,4*sx@sm@sx))
+        Mpx = kron(sx, id2)
+        Mpy = kron(sy, id2)
+        Mpz = kron(sz, id2)
+        print (H)
+    else:
+        print ('unknown model', args.model)
+        sys.exit(1)
+
+    def closure():
+        optimizer.zero_grad()
+        start = time.time()
+        loss, Mx, My, Mz = model.forward(H, Mpx, Mpy, Mpz)
+        forward = time.time()
+        loss.backward()
+        print (model.A.norm().item(), model.A.grad.norm().item(), loss.item(), Mx.item(), My.item(), Mz.item(), torch.sqrt(Mx**2+My**2+Mz**2).item(), forward-start, time.time()-forward)
+        return loss
+
+    with io.open(key+'.log', 'a', buffering=1, newline='\n') as logfile:
+        for epoch in range(args.Nepochs):
+            loss = optimizer.step(closure)
+            if (epoch%args.save_period==0):
+                save_checkpoint(key+'/peps.tensor', model, optimizer)
+
+            with torch.no_grad():
+                En, Mx, My, Mz = model.forward(H, Mpx, Mpy, Mpz)
+                Mg = torch.sqrt(Mx**2+My**2+Mz**2)
+                message = ('{} ' + 5*'{:.16f} ').format(epoch, En, Mx, My, Mz, Mg)
+                print ('epoch, En, Mx, My, Mz, Mg', message)
+                logfile.write(message + u'\n')
diff --git a/measure.py b/measure.py
new file mode 100644
index 0000000..62a2362
--- /dev/null
+++ b/measure.py
@@ -0,0 +1,26 @@
+import torch
+
+def get_obs(Asymm, H, Sx, Sy, Sz, C, E):
+    # A(phy,u,l,d,r)
+
+    Da = Asymm.size()
+    Td = torch.einsum('mxyzw,nabcd->xaybzcwdmn',(Asymm,Asymm)).contiguous().view(Da[1]**2, Da[2]**2, Da[3]**2, Da[4]**2, Da[0], Da[0])
+
+    CE = torch.tensordot(C,E,([1],[0]))        # C(1d)E(dga)->CE(1ga)
+    EL = torch.tensordot(E,CE,([2],[0]))       # E(2e1)CE(1ga)->EL(2ega)
+    EL = torch.tensordot(EL,Td,([1,2],[1,0]))  # EL(2ega)T(gehbmn)->EL(2ahbmn)
+    EL = torch.tensordot(EL,CE,([0,2],[0,1]))  # EL(2ahbmn)CE(2hc)->EL(abmnc)=EL(12mn3)
+    Rho = torch.tensordot(EL,EL,([0,1,4],[0,1,4])).permute(0,2,1,3).contiguous().view(Da[0]**2,Da[0]**2)
+
+    # print( (Rho-Rho.t()).norm() )
+    Rho = 0.5*(Rho + Rho.t())
+
+    Tnorm = Rho.trace()
+    Energy = torch.mm(Rho,H).trace()/Tnorm
+    Mx = torch.mm(Rho,Sx).trace()/Tnorm
+    My = torch.mm(Rho,Sy).trace()/Tnorm
+    Mz = torch.mm(Rho,Sz).trace()/Tnorm
+
+    #print("Tnorm = %g, Energy = %g " % (Tnorm.item(), Energy.item()) )
+
+    return Energy, Mx, My, Mz
diff --git a/renormalize.py b/renormalize.py
new file mode 100644
index 0000000..6d34425
--- /dev/null
+++ b/renormalize.py
@@ -0,0 +1,50 @@
+import torch
+from adlib import EigenSolver
+symeig = EigenSolver.apply
+from args import args
+
+def renormalize(*tensors):
+    T, C, E = tensors
+
+    D, d = E.shape[0], E.shape[1]
+    # M = torch.einsum('ab,eca,bdg,cdfh->efgh', (C, E, E, T)).contiguous().view(D*d, D*d)
+    M = torch.tensordot(E,C,dims=1)               # E(eca)*C(ab)=M(ecb)
+    M = torch.tensordot(M,E,dims=1)               # M(ecb)*E(bdg)=M(ecdg)
+    M = torch.tensordot(M,T,dims=([1,2],[1,0]))   # M(ecdg)*T(dcfh)=M(egfh)
+    M = M.permute(0,2,1,3).contiguous().view(D*d, D*d)  # M(egfh)->M(ef;gh)
+
+    M = (M+M.t())/2.
+    M = M/M.norm()
+
+    D_new = min(D*d, args.chi)
+    if (not torch.isfinite(M).all()):
+        print ('M is not finite!!')
+
+    #U, S, V = svd(M)
+    #truncation_error = S[D_new:].sum()/S.sum()
+    #P = U[:, :D_new] # projection operator
+
+    #S, U = torch.symeig(M, eigenvectors=True)
+    S, U = symeig(M)
+    sorted, indices = torch.sort(S.abs(), descending=True)
+    truncation_error = sorted[D_new:].sum()/sorted.sum()
+    S = S[indices][:D_new]
+    P = U[:, indices][:, :D_new] # projection operator
+
+    C = (P.t() @ M @ P) #(D, D)
+    C = (C+C.t())/2.
+
+    ## EL(u,r,d)
+    P = P.view(D,d,D_new)
+    E = torch.tensordot(E, P, ([0],[0]))      # E(dhf)P(dea)=E(hfea)
+    E = torch.tensordot(E, T, ([0,2],[1,0]))  # E(hfea)T(ehgb)=E(fagb)
+    E = torch.tensordot(E, P, ([0,2],[0,1]))  # E(fagb)P(fgc)=E(abc)
+
+    #ET = torch.einsum('ldr,adbc->labrc', (E, T)).contiguous().view(D*d, d, D*d)
+    #ET = torch.tensordot(E, T, dims=([1], [1]))
+    #ET = ET.permute(0, 2, 3, 1, 4).contiguous().view(D*d, d, D*d)
+    #E = torch.einsum('li,ldr,rj->idj', (P, ET, P)) #(D_new, d, D_new)
+
+    E = (E + E.permute(2, 1, 0))/2.
+
+    return C/C.norm(), E/E.norm(), S/S.max(), truncation_error
diff --git a/trg.py b/trg.py
new file mode 100644
index 0000000..53b4d0e
--- /dev/null
+++ b/trg.py
@@ -0,0 +1,71 @@
+import numpy as np
+import torch
+from adlib import SVD
+svd = SVD.apply
+
+def TRG(K, Dcut, no_iter, device='cpu', epsilon=1E-15):
+    D = 2
+
+    lam = [torch.cosh(K)*np.sqrt(2), torch.sinh(K)*np.sqrt(2)]
+    T = []
+    for i in range(D):
+        for j in range(D):
+            for k in range(D):
+                for l in range(D):
+                    if ((i+j-k-l)%2==0):
+                        T.append(torch.sqrt(lam[i]*lam[j]*lam[k]*lam[l]))
+                    else:
+                        T.append(torch.tensor(0.0, dtype=K.dtype, device=K.device))
+    T = torch.stack(T).view(D, D, D, D)
+
+    lnZ = 0.0
+    for n in range(no_iter):
+
+        #print(n, " ", T.max(), " ", T.min())
+        maxval = T.abs().max()
+        T = T/maxval
+        lnZ += 2**(no_iter-n)*torch.log(maxval)
+
+        Ma = T.permute(2, 1, 0, 3).contiguous().view(D**2, D**2)
+        Mb = T.permute(3, 2, 1, 0).contiguous().view(D**2, D**2)
+
+        Ua, Sa, Va = svd(Ma)
+        Ub, Sb, Vb = svd(Mb)
+
+        D_new = min(min(D**2, Dcut), min((Sa>epsilon).sum().item(), (Sb>epsilon).sum().item()))
+
+        S1 = (Ua[:, :D_new]* torch.sqrt(Sa[:D_new])).view(D, D, D_new)
+        S3 = (Va[:, :D_new]* torch.sqrt(Sa[:D_new])).view(D, D, D_new)
+        S2 = (Ub[:, :D_new]* torch.sqrt(Sb[:D_new])).view(D, D, D_new)
+        S4 = (Vb[:, :D_new]* torch.sqrt(Sb[:D_new])).view(D, D, D_new)
+
+        T_new = torch.einsum('war,abu,bgl,gwd->ruld', (S1, S2, S3, S4))
+
+        D = D_new
+        T = T_new
+
+    trace = 0.0
+    for i in range(D):
+        trace += T[i, i, i, i]
+    lnZ += torch.log(trace)
+
+    return lnZ
+
+if __name__=="__main__":
+    import argparse
+    parser = argparse.ArgumentParser(description='')
+    parser.add_argument("-float32", action='store_true', help="use float32")
+    parser.add_argument("-cuda", type=int, default=-1, help="GPU device index, -1 for CPU")
+    args = parser.parse_args()
+    device = torch.device("cpu" if args.cuda<0 else "cuda:"+str(args.cuda))
+    dtype = torch.float32 if args.float32 else torch.float64
+
+    Dcut = 24
+    n = 20
+
+    for K in np.linspace(0.4, 0.5, 11):
+        beta = torch.tensor(K, dtype=dtype, device=device).requires_grad_()
+        lnZ = TRG(beta, Dcut, n, device=device)
+        dlnZ, = torch.autograd.grad(lnZ, beta, create_graph=True)  # En = -d lnZ / d beta
+        dlnZ2, = torch.autograd.grad(dlnZ, beta)                   # Cv = beta^2 * d^2 lnZ / d beta^2
+        print (K, lnZ.item()/2**n, -dlnZ.item()/2**n, dlnZ2.item()*beta.item()**2/2**n)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..6763cf1
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,68 @@
+import torch
+
+def kronecker_product(t1, t2):
+    """
+    Computes the Kronecker product between two 2d tensors.
+    See https://en.wikipedia.org/wiki/Kronecker_product
+    """
+    t1_height, t1_width = t1.size()
+    t2_height, t2_width = t2.size()
+    out_height = t1_height * t2_height
+    out_width = t1_width * t2_width
+
+    tiled_t2 = t2.repeat(t1_height, t1_width)
+    expanded_t1 = (
+        t1.unsqueeze(2)
+          .unsqueeze(3)
+          .repeat(1, t2_height, t2_width, 1)
+          .view(out_height, out_width)
+    )
+
+    return expanded_t1 * tiled_t2
+
+def symmetrize(A):
+    '''
+    A(phy,u,l,d,r)
+    left-right, up-down, diagonal symmetrize
+    '''
+    Asymm = (A + A.permute(0, 1, 4, 3, 2))/2.
+    Asymm = (Asymm + Asymm.permute(0, 3, 2, 1, 4))/2.
+    Asymm = (Asymm + Asymm.permute(0, 2, 1, 4, 3))/2.
+    Asymm = (Asymm + Asymm.permute(0, 4, 3, 2, 1))/2.
+
+    return Asymm/Asymm.norm()
+
+def save_checkpoint(checkpoint_path, model, optimizer):
+    state = {'state_dict': model.state_dict(),
+             'optimizer' : optimizer.state_dict()}
+    #print(model.state_dict().keys())
+    torch.save(state, checkpoint_path)
+    print('model saved to %s' % checkpoint_path)
+
+def load_checkpoint(checkpoint_path, model, optimizer):
+    state = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+    Dold = state['state_dict']['A'].shape[1]
+    if (Dold != model.D): # bond dimension changed: embed the old tensor into a larger one
+        A = 0.01*torch.randn(model.d, model.D, model.D, model.D, model.D, dtype=model.A.dtype, device=model.A.device) # some perturbation
+        A[:, :Dold, :Dold, :Dold, :Dold] = state['state_dict']['A']
+        state['state_dict']['A'] = A
+    else: # D unchanged, so the saved optimizer state is still compatible and can be restored
+        optimizer.load_state_dict(state['optimizer'])
+    #print (state['state_dict'])
+    model.load_state_dict(state['state_dict'])
+    print('model loaded from %s' % checkpoint_path)
+
+if __name__=='__main__':
+    import torch
+    A = torch.arange(4).view(2,2)
+    B = torch.arange(9).view(3,3)
+    print (A)
+    print (B)
+    print (kronecker_product(A, B))
+
+    A = torch.randn(2,4,4,4,4)
+    print ('A', A)
+    Asymm = symmetrize(A)
+    print ('A', A)
+    print ('Asymm', Asymm)
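A quick way to exercise the pieces introduced above, outside of the patch itself: the following minimal sketch (not part of the diff; the matrix, scalar loss, seed, and tolerance are arbitrary illustrative choices, and it assumes the repository root is on PYTHONPATH) backpropagates through the fixed-point power iteration defined in adlib/power.py.

    import torch
    from adlib.power import FixedPoint

    # build a random symmetric matrix as a leaf tensor that requires gradients
    torch.manual_seed(0)
    N = 6
    A0 = torch.randn(N, N, dtype=torch.float64)
    A = ((A0 + A0.t())/2).requires_grad_()

    # normalized starting vector for the power iteration
    x0 = torch.rand(N, dtype=torch.float64)
    x0 = x0/x0.norm()

    # dominant eigenvector obtained as a fixed point of step(A, x)
    x = FixedPoint.apply(A, x0, 1E-10)

    # any scalar function of the fixed point; its gradient with respect to A flows
    # through FixedPoint.backward (implicit differentiation), not through the forward loop
    loss = (x**4).sum()
    loss.backward()
    print(loss.item(), A.grad.norm().item())

The same pattern, a scalar loss downstream of a custom autograd.Function, is what main.py relies on when it differentiates the CTMRG contraction with respect to the iPEPS tensor.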