
Commit

some codes for release
wangleiphy committed Mar 4, 2019
0 parents commit f670238
Showing 15 changed files with 639 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
*~
__pycache__
.DS_Store
1 change: 1 addition & 0 deletions __init__.py
@@ -0,0 +1 @@
import adlib as adlib
3 changes: 3 additions & 0 deletions adlib/__init__.py
@@ -0,0 +1,3 @@
from .svd import SVD
from .eigh import EigenSolver
from .qr import QR
40 changes: 40 additions & 0 deletions adlib/eigh.py
@@ -0,0 +1,40 @@
import numpy as np
import torch

class EigenSolver(torch.autograd.Function):
    @staticmethod
    def forward(self, A):
        w, v = torch.symeig(A, eigenvectors=True)

        self.save_for_backward(w, v)
        return w, v

    @staticmethod
    def backward(self, dw, dv):
        '''
        https://j-towns.github.io/papers/svd-derivative.pdf
        '''
        w, v = self.saved_tensors
        dtype, device = w.dtype, w.device
        N = v.shape[0]

        # F_ij = 1/(w_j - w_i) off the diagonal; the diagonal becomes 0 via 1/inf
        F = w - w[:, None]
        F.diagonal().fill_(np.inf)
        F = 1./F

        vt = v.t()
        vdv = vt@dv

        return v@(torch.diag(dw) + F*(vdv-vdv.t())/2) @vt

def test_eigs():
    M = 2
    torch.manual_seed(42)
    A = torch.rand(M, M, dtype=torch.float64)
    A = torch.nn.Parameter(A+A.t())
    assert(torch.autograd.gradcheck(EigenSolver.apply, A, eps=1e-6, atol=1e-4))
    print("Test Pass!")

if __name__=='__main__':
    test_eigs()
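
A minimal usage sketch (not part of this commit; it assumes the code is run from the repository root so that adlib is importable) of backpropagating through EigenSolver, with an arbitrary scalar loss:

import torch
from adlib import EigenSolver

torch.manual_seed(0)
A = torch.rand(4, 4, dtype=torch.float64, requires_grad=True)
w, v = EigenSolver.apply((A + A.t())/2)   # symmetrize before the eigensolver
loss = w[-1]                              # largest eigenvalue as an example loss
loss.backward()                           # A.grad now holds d(largest eigenvalue)/dA
print(A.grad)
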
70 changes: 70 additions & 0 deletions adlib/power.py
@@ -0,0 +1,70 @@
import torch
from torch.utils.checkpoint import detach_variable

def step(A, x):
    y = A@x
    y = y[0].sign() * y
    return y/y.norm()

class FixedPoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, x0, tol):
        x, x_prev = step(A, x0), x0
        while torch.dist(x, x_prev) > tol:
            x, x_prev = step(A, x), x
        ctx.save_for_backward(A, x)
        ctx.tol = tol
        return x

    @staticmethod
    def backward(ctx, grad):
        A, x = detach_variable(ctx.saved_tensors)
        dA = grad
        while True:
            with torch.enable_grad():
                grad = torch.autograd.grad(step(A, x), x, grad_outputs=grad)[0]
            if (torch.norm(grad) > ctx.tol):
                dA = dA + grad
            else:
                break
        with torch.enable_grad():
            dA = torch.autograd.grad(step(A, x), A, grad_outputs=dA)[0]
        return dA, None, None

def test_backward():
    N = 4
    torch.manual_seed(2)
    A = torch.rand(N, N, dtype=torch.float64, requires_grad=True)
    x0 = torch.rand(N, dtype=torch.float64)
    x0 = x0/x0.norm()
    tol = 1E-10

    input = A, x0, tol
    assert(torch.autograd.gradcheck(FixedPoint.apply, input, eps=1E-6, atol=tol))

    print("Backward Test Pass!")

def test_forward():
    torch.manual_seed(42)
    N = 100
    tol = 1E-8
    dtype = torch.float64
    A = torch.randn(N, N, dtype=dtype)
    A = A+A.t()

    w, v = torch.symeig(A, eigenvectors=True)
    idx = torch.argmax(w.abs())

    v_exact = v[:, idx]
    v_exact = v_exact[0].sign() * v_exact

    x0 = torch.rand(N, dtype=dtype)
    x0 = x0/x0.norm()
    x = FixedPoint.apply(A, x0, tol)

    assert(torch.allclose(v_exact, x, rtol=tol, atol=tol))
    print("Forward Test Pass!")

if __name__=='__main__':
    test_forward()
    test_backward()
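
A short sketch (not part of this commit) of using FixedPoint to differentiate a function of the dominant eigenvector; the loss is an arbitrary choice, and A is symmetrized so the power iteration behaves well:

import torch
from adlib.power import FixedPoint

torch.manual_seed(0)
N = 6
A0 = torch.rand(N, N, dtype=torch.float64, requires_grad=True)
A = (A0 + A0.t())/2
x0 = torch.rand(N, dtype=torch.float64)
x0 = x0/x0.norm()

x = FixedPoint.apply(A, x0, 1E-12)   # dominant eigenvector from power iteration
loss = (x**4).sum()                  # arbitrary scalar function of the eigenvector
loss.backward()                      # backward iterates the adjoint equation to tolerance
print(A0.grad.norm())
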
53 changes: 53 additions & 0 deletions adlib/qr.py
@@ -0,0 +1,53 @@
import torch

class QR(torch.autograd.Function):
    @staticmethod
    def forward(self, A):
        Q, R = torch.qr(A)
        self.save_for_backward(A, Q, R)
        return Q, R

    @staticmethod
    def backward(self, dq, dr):
        A, q, r = self.saved_tensors
        if r.shape[0] == r.shape[1]:
            return _simple_qr_backward(q, r, dq, dr)
        M, N = r.shape
        B = A[:, M:]
        dU = dr[:, :M]
        dD = dr[:, M:]
        U = r[:, :M]
        da = _simple_qr_backward(q, U, dq+B@dD.t(), dU)
        db = q@dD
        return torch.cat([da, db], 1)

def _simple_qr_backward(q, r, dq, dr):
    if r.shape[-2] != r.shape[-1]:
        raise NotImplementedError("QrGrad not implemented when ncols > nrows "
                                  "or full_matrices is true and ncols != nrows.")

    qdq = q.t() @ dq
    qdq_ = qdq - qdq.t()
    rdr = r @ dr.t()
    rdr_ = rdr - rdr.t()
    tril = torch.tril(qdq_ + rdr_)

    def _TriangularSolve(x, r):
        """Equiv to x @ torch.inverse(r).t() if r is upper-tri."""
        res = torch.trtrs(x.t(), r, upper=True, transpose=False)[0].t()
        return res

    grad_a = q @ (dr + _TriangularSolve(tril, r))
    grad_b = _TriangularSolve(dq - q @ qdq, r)
    return grad_a + grad_b

def test_qr():
    M, N = 4, 6
    torch.manual_seed(2)
    A = torch.randn(M, N)
    A.requires_grad = True
    assert(torch.autograd.gradcheck(QR.apply, A, eps=1e-4, atol=1e-2))
    print("Test Pass!")

if __name__ == "__main__":
    test_qr()
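
A brief usage sketch (not part of this commit) of the differentiable QR factorization; a tall input keeps R square, which is the case _simple_qr_backward handles directly:

import torch
from adlib import QR

torch.manual_seed(0)
A = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
Q, R = QR.apply(A)                    # thin QR: Q is 5x3 and R is 3x3
loss = Q.sum() + R.diagonal().sum()   # arbitrary scalar depending on both factors
loss.backward()
print(A.grad)
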
61 changes: 61 additions & 0 deletions adlib/svd.py
@@ -0,0 +1,61 @@
import numpy as np
import scipy.linalg
import torch, pdb

def safe_inverse(x, epsilon=1E-12):
    return x/(x**2 + epsilon)

class SVD(torch.autograd.Function):
    @staticmethod
    def forward(self, A):
        U, S, V = torch.svd(A)
        #numpy_input = A.detach().numpy()
        #U, S, Vt = scipy.linalg.svd(numpy_input, full_matrices=False, lapack_driver='gesvd')
        #U = torch.as_tensor(U, dtype=A.dtype, device=A.device)
        #S = torch.as_tensor(S, dtype=A.dtype, device=A.device)
        #V = torch.as_tensor(np.transpose(Vt), dtype=A.dtype, device=A.device)
        self.save_for_backward(U, S, V)
        return U, S, V

    @staticmethod
    def backward(self, dU, dS, dV):
        U, S, V = self.saved_tensors
        Vt = V.t()
        Ut = U.t()
        M = U.size(0)
        N = V.size(0)
        NS = len(S)

        F = (S - S[:, None])
        F = safe_inverse(F)
        F.diagonal().fill_(0)

        G = (S + S[:, None])
        G.diagonal().fill_(np.inf)
        G = 1/G

        UdU = Ut @ dU
        VdV = Vt @ dV

        Su = (F+G)*(UdU-UdU.t())/2
        Sv = (F-G)*(VdV-VdV.t())/2

        dA = U @ (Su + Sv + torch.diag(dS)) @ Vt
        if (M>NS):
            dA = dA + (torch.eye(M, dtype=dU.dtype, device=dU.device) - U@Ut) @ (dU/S) @ Vt
        if (N>NS):
            dA = dA + (U/S) @ dV.t() @ (torch.eye(N, dtype=dU.dtype, device=dU.device) - V@Vt)
        #print (dU.norm().item(), dS.norm().item(), dV.norm().item())
        #print (Su.norm().item(), Sv.norm().item(), dS.norm().item())
        #print (dA1.norm().item(), dA2.norm().item(), dA3.norm().item())
        return dA

def test_svd():
    M, N = 50, 40
    torch.manual_seed(2)
    input = torch.rand(M, N, dtype=torch.float64, requires_grad=True)
    assert(torch.autograd.gradcheck(SVD.apply, input, eps=1e-6, atol=1e-4))
    print("Test Pass!")

if __name__=='__main__':
    test_svd()
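
A usage sketch (not part of this commit) of the differentiable SVD in the way it is typically used in tensor networks: truncate to the leading chi singular values and backpropagate through the truncation:

import torch
from adlib import SVD

torch.manual_seed(0)
A = torch.rand(6, 4, dtype=torch.float64, requires_grad=True)
U, S, V = SVD.apply(A)                           # A = U @ diag(S) @ V.t()
chi = 2
A_trunc = (U[:, :chi]*S[:chi]) @ V[:, :chi].t()  # rank-chi approximation of A
loss = (A - A_trunc).norm()                      # truncation error as the loss
loss.backward()                                  # gradient flows through U, S and V
print(A.grad.norm())
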
21 changes: 21 additions & 0 deletions args.py
@@ -0,0 +1,21 @@
import argparse

parser = argparse.ArgumentParser(description='')
parser.add_argument("-folder", default='../data/', help="where to store results")
parser.add_argument("-d", type=int, default=2, help="physical dimension d")
parser.add_argument("-D", type=int, default=2, help="iPEPS bond dimension D")
parser.add_argument("-chi", type=int, default=20, help="CTMRG environment bond dimension chi")
parser.add_argument("-Nepochs", type=int, default=50, help="number of optimization epochs")
parser.add_argument("-Maxiter", type=int, default=50, help="maximum number of CTMRG iterations")
parser.add_argument("-Jz", type=float, default=1.0, help="coupling Jz")
parser.add_argument("-Jxy", type=float, default=1.0, help="coupling Jxy")
parser.add_argument("-hx", type=float, default=1.0, help="transverse field hx")
parser.add_argument("-model", default='Heisenberg', choices=['TFIM', 'Heisenberg'], help="model name")
parser.add_argument("-load", default=None, help="saved results to load")
parser.add_argument("-save_period", type=int, default=1, help="how often (in epochs) to save")
parser.add_argument("-float32", action='store_true', help="use float32")
parser.add_argument("-use_checkpoint", action='store_true', help="use gradient checkpointing to save memory")
parser.add_argument("-cuda", type=int, default=-1, help="GPU device index (-1 for CPU)")

args = parser.parse_args()
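
Since parse_args() runs at module level, the flags are parsed once at import time and every module that does `from args import args` sees the same values. A small sketch of how this is used; the driver script name in the comment is a placeholder, not part of this commit:

# e.g. python run.py -model TFIM -hx 3.0 -D 3 -chi 30 -Nepochs 100
from args import args
print(args.model, args.D, args.chi)   # flags are parsed once and shared by every module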

51 changes: 51 additions & 0 deletions ctmrg.py
@@ -0,0 +1,51 @@
import torch
from renormalize import renormalize
from torch.utils.checkpoint import checkpoint

def ctmrg(T, chi, max_iter, use_checkpoint=False):

    threshold = 1E-12 if T.dtype is torch.float64 else 1E-6 # ctmrg convergence threshold

    C = T.sum((0,1))
    E = T.sum(1)

    truncation_error = 0.0
    diff = 1E10
    sold = torch.zeros(chi, dtype=T.dtype, device=T.device)
    for n in range(max_iter):
        tensors = T, C, E
        if use_checkpoint: # use checkpoint to save memory
            C, E, s, error = checkpoint(renormalize, *tensors)
        else:
            C, E, s, error = renormalize(*tensors)

        truncation_error = max(truncation_error, error.item())
        if (s.numel() == sold.numel()):
            diff = (s-sold).norm().item()
        #print( 'n: %d, error: %e, diff: %e' % (n, error.item(), diff) )

        if (diff < threshold):
            break
        sold = s
    print ('ctmrg iterations, diff, error', n, diff, truncation_error/n)

    return C, E

if __name__=='__main__':
    import time
    D = 64
    chi = 150
    device = 'cuda:0'
    T = torch.randn(D, D, D, D, dtype=torch.float64, device=device, requires_grad=True)
    T = (T + T.permute(3, 1, 2, 0))/2.
    T = (T + T.permute(0, 2, 1, 3))/2.
    T = (T + T.permute(2, 3, 0, 1))/2.
    T = (T + T.permute(1, 0, 3, 2))/2.
    T = T/T.norm()

    C = torch.randn(chi, chi, dtype=torch.float64, device=device, requires_grad=True)
    C = (C+C.t())/2.
    E = torch.randn(chi, D, chi, dtype=torch.float64, device=device, requires_grad=True)
    E = (E + E.permute(2, 1, 0))/2.
    args = C, E, T, torch.tensor(chi)
    checkpoint(renormalize, *args)
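
A sketch (not part of this commit; it assumes renormalize.py from the unexpanded files is importable) of running CTMRG on a small symmetric tensor and differentiating through the converged environment:

import torch
from ctmrg import ctmrg

torch.manual_seed(0)
D, chi = 4, 16
T0 = torch.rand(D, D, D, D, dtype=torch.float64, requires_grad=True)
T = (T0 + T0.permute(3, 1, 2, 0))/2.   # impose the same symmetries as in the __main__ block above
T = (T + T.permute(0, 2, 1, 3))/2.
T = (T + T.permute(2, 3, 0, 1))/2.
T = (T + T.permute(1, 0, 3, 2))/2.
T = T/T.norm()

C, E = ctmrg(T, chi, max_iter=50)      # converged corner and edge tensors
loss = C.trace()                       # arbitrary scalar function of the environment
loss.backward()                        # backpropagates through every CTMRG iteration
print(T0.grad.norm())
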
26 changes: 26 additions & 0 deletions ipeps.py
@@ -0,0 +1,26 @@
import torch
from ctmrg import ctmrg
from measure import get_obs
from utils import symmetrize
from args import args

class iPEPS(torch.nn.Module):
    def __init__(self, args, dtype=torch.float64, device='cpu', use_checkpoint=False):
        super(iPEPS, self).__init__()

        B = torch.rand(args.d, args.D, args.D, args.D, args.D, dtype=dtype, device=device)
        B = B/B.norm()
        self.A = torch.nn.Parameter(B)

    def forward(self, H, Mpx, Mpy, Mpz):

        Asymm = symmetrize(self.A)

        # double-layer tensor: contract the physical index of two copies of Asymm,
        # then merge bra/ket bonds into D**2-dimensional legs
        d, D = Asymm.shape[0], Asymm.shape[1]
        T = (Asymm.view(d, -1).t()@Asymm.view(d, -1)).view(D, D, D, D, D, D, D, D).permute(0,4, 1,5, 2,6, 3,7).contiguous().view(D**2, D**2, D**2, D**2)
        T = T/T.norm()

        C, E = ctmrg(T, args.chi, args.Maxiter, args.use_checkpoint)
        loss, Mx, My, Mz = get_obs(Asymm, H, Mpx, Mpy, Mpz, C, E)

        return loss, Mx, My, Mz
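
One possible optimization loop over the energy returned by iPEPS, sketched for orientation (not among the files expanded above; H, Mpx, Mpy, Mpz stand in for the two-site Hamiltonian and single-site operators that the model-definition files of this commit are expected to supply, and the choice of L-BFGS is an assumption):

import torch
from args import args
from ipeps import iPEPS

model = iPEPS(args, dtype=torch.float64, device='cpu')
optimizer = torch.optim.LBFGS(model.parameters(), max_iter=10)

def closure():
    optimizer.zero_grad()
    # H, Mpx, Mpy, Mpz are assumed to be built by the model code elsewhere in this commit
    loss, Mx, My, Mz = model(H, Mpx, Mpy, Mpz)
    loss.backward()                  # d(energy)/d(iPEPS tensor) via the custom backwards above
    return loss

for epoch in range(args.Nepochs):
    loss = optimizer.step(closure)
    print(epoch, loss.item())
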
