This repository was archived by the owner on Nov 7, 2024. It is now read-only.

Commit fb27805

mganahl, coryell, Chase Roberts, amilsted, and summer-bebop authored
eig, eigh, svd, qr (#461)
* started implementing block-sparse tensors * removed files * working on AbelianIndex * working in block sparisty * added reshape and lots of other stuff * added Index, an index type for symmetric tensors * added small tutorial * added docstring * fixed bug in retrieve_diagonal_blocks * TODO added * improved initialization a bit * more efficient initialization * just formatting * added random * added fuse_degeneracies * fixed bug in reshape * dosctring, typing * removed TODO * removed confusing code line * bug removed * comment * added __mul__ to Index * added sparse_shape and updated reshape to accept both int and Index lists * more in tutorial * comment * added new test function * testing function hacking * docstring * small speed up * Remove gui directory (migrated to another repo) (#399) * a slightly more elegant code * use one more np function * removed some crazy slow code * faster code * Update README.md (#404) * add return_data * doc * bug fix * a little faster * substantial speedup * renaming * removed todo * some comments * comments * fixed some bug in reshape * comments * default value changed * fixed bug, old version is now faster again * cleaned up reshape * started adding tests * Quantum abstractions (#379) * Initial attempt at quantum classes. * .tensor() and QuantumIdentity. * check_hilberts -> check_spaces * Add some blurb. * Rename to Qu. * Finish Qu. * Fix matmul in case operators share network components. * Add some scalar methods. * Improve a docstring. * Redo scalars and make identity work using copy tensors. A QuOperator can now have a scalar component (disconnected scalar subnetwork). Also introduce `ignore_edges` for power users. * Remove obsolete parts of QuScalar. * Add contraction stuff. * Add from_tensor() constructors. * Doctstring. * Doc/comments fixes. * Add typing. * Remove some lint. * Fix a bug. * Add very simple constructor tests. * Default edge ordering for eval(). Better docstrings. * A bunch more tests. * tensor_prod -> tensor_product, outer_product * .is_scalar -> is_scalar() etc. * Improve docstrings on axis ordering. * Improve and fix scalar multiplication. * Kill outer_product(). * CopyNode needs a backend and dtype. * Fix __mul__ and add __rmul__. * More docstrings and add axis arguments to vector from_tensor()s. * Add backends to tests. * Delint. * CopyNode should not inflate its tensor just to tell you the dtype. * Correct two docstrings. * Improve some tests. Particulary, test identity some more, since we now try to be efficient with CopyNode identity tensors. * Treat CopyNode identity tensors efficiently. Also rename shape -> space. * Add support for copying CopyNodes. * Propagate output edges properly. * Test that CopyNodes are propagated. Also do a CopyNode sandwich test. * Improve typing. Also more shape -> space. 
* adding random uniform initialization (#412) * adding random uniform initialization * fixes dumb pylint * couple of nit picks * replace kron with broadcasting * column-major -> row-major * documentation * added function to compute unique charges and charge degeneracies Function avoids explicit full fusion of all legs, and instead only keeps track of the unique charges and their degeneracies upon fusion * improved block finding, fixed bug in reshape re-intorduced BlockSparseTensor.dense_shape new method for fusing charges and degeneracies (faster for very rectangular matrices) * fuse_charge_pair added fuse_charges added * use is_leave * new tests * removed TODO, BlockSparseTensor.shape returns ref instead of copy * added tests * added tests * column-major -> row-major forgot to fix fusing order of charges and degeneracies * fix broken tests * test added * mostly docstring * docstring * added map_to_integer * test for map_to_integer * added functions to find sparse positions when fusing two charges * renaming of routines * added unfuse * test unfuse * fixed bug in the new routine for finding diagonal blocks * test added * docstring * added tests * renaming * tests * transpose added, map_to_integer removed (used np.unravel_index and np.ravel_multi_index instead) * find_dense_positions made faster * working on transpose * transpose modified * Index.name -> property * added charge types * adding tests * fixing bugs * implementing more symmetries * typo + remove cython lookup * split charge.py from index.py * tests for charge.py * test added * added matmul * added test for matmul * tests + allow np.int8 * typo * undo typo * test\ * savety commit, starting to add multiple charges * Charge -> ChargeCollection * removed offsets from U1Charge (unnecessary), Charge -> ChargeCollection * new tests * tests for new index * new Index class * new block tensor * shorter code * add __eq__, remove nonzero, add unique * working on charges.py * fix small bug in BaseCharge.__init__ * fix tests after bugfix * tests for equals() and __eq__ * added equals() for comparing with unshifted target charges __eq__ now only compares shifted target charges * added typing * ChargeCollection.__repr__ modified * *** empty log message *** * this commit is not working * fix bug in __len__ fix various bugs in __eq__ and __getitem__ * working in implemetation of multiple charges * bugfix in ChargeCollection.__getitem__ * adding tests * sleep commit * added iterators * ChargeCollection.__init__: charges are now always stacked, self.charges contain views to the stacked charges __init__ can be called with optional shifts and stacked_charges to initialize the BaseCharges object with it * lunch commit * back from lunch * tests added * ported find_sparse_positions and find_diagonal_sparse_blocks to new charge interface * broken commit * fixed bug in find_dense_positions * fix bug in find_dense_positions * docstring * fix bug in Index initialization * fix bug in Index initialization * typo * remove __rmul__ calls of ChargeCollection and BaseCharge * removed __rmul__ * removed some bugs inb transpose * broken commit * broken commit * remove csr matrix, use search sorted * remove unfuse, use divmod * broken commit, working on tensordot * tensordot implemented, not tested * removed commented codex * fix tests * fix tests * added test for BlockSparseTensor back * renaming files * fix tutorial, fix import * faster find_dense_positions * compute reduced svd in `backends.numpy.decompositions.svd_decompostion` (#420) * compute reduced svd when 
calling np.linalg.svd from numpy backend * test SVD when max_singular_values>bond_dimension (numpy backend) * added intersect to BaseCharge * broken commmit (Apple sucks big time) * broken commit * broken commit * broken commit * Fixes for contract_between(). (#421) * Fixes for contract_between(). * output_edge_ordering was not respected in trace or outer_product cases. * axis names are now applied *after* edge reordering * added some docstring to clarify ordering * avoid a warning when contracting all edges of one or more of the input tensors. * Split out ordering tests. Also improves the basic contract_between() test so that it outputs a non-symmetric matrix (now rectangular). * broken commit * broken * added `return_indices` to intersect * faster transpose + tensordot implementation * Update requirements_travis.txt (#426) * rewrote find_dense_positions to take multipe charges avoids a for loop in _find_diagonal_dense_blocks and speeds up the code * find_sparse_positions update to be a tiny bit faster * Remove duplicate Dockerfile from root directory (#431) * BaseNode / Edge class name type check protection add (#424) * BaseNode / Edge class text input protection added (#423) BaseNode class - Add protection to name, axis_names *Protected in 3 place *Initialize stage - __init__ *Function use setting - set_name / add_axis_names *Property - Add @Property to name to protect direct adding node.name = 123 Edge class - Add protection to name *Protected in 3 place *Initialize stage - __init__ *Function use setting - set_name *Property * BaseNode / Edge class text input protection code revise (#423) *if type(name) != str *if not isinstance(name, str) *change using type to isinstance to follow pylint Co-authored-by: Chase Roberts <keeper6928@gmail.com> * fix bug in _get_diagonal_dense_blocks * fix bug * fixed bug in transpose * Test network operations (#441) * added test for mps switch backend * added switch backend method to MPS * added test for network operations switch backend * make sure switch_backend not only fixes tensor but also node property * added switch_backend to init * added missing tests for network operations * some linting issues * Rename backend shape methods (#355) (#436) concat function * rename from cocate to shape_concat shape function * rename from shape to shape_tensor prod function * rename from prod to shape_prod * function name is duplicated in shell_backend.py * rename existing shape_prod function to shape_product * need to change the name later Co-authored-by: Chase Roberts <keeper6928@gmail.com> * fixed final_order passing for tensordot * Added SAT Tutorial (#438) * Add files via upload Added SAT Tutorials * Update SATTutorial.ipynb * Update SATTutorial.ipynb * Update SATTutorial.ipynb * Update SATTutorial.ipynb * License changed * Created using Colaboratory Co-authored-by: Chase Roberts <chaseriley@google.com> * More Test! 
(#444) * added test for mps switch backend * added switch backend method to MPS * added test for network operations switch backend * make sure switch_backend not only fixes tensor but also node property * added switch_backend to init * added a lot of tests for network components * a lot more tests * some more tests * some linter things * added test base class instead of hack * disabled some pytype warnings * disabled some pylint warnings * Return empty dict for empty sequence input to MPS left_envs and right_envs (#440) * Return empty dict for empty input to MPS envs * Add tests for empty sequence input to MPS envs * Use explicit sequences for MPS envs tests Co-authored-by: Chase Roberts <chaseriley@google.com> * Issue #339. with tn.DefaultBackend(backend): support (#434) * A context manager support implementation for setting up a backend for Nodes. (Issue #339) * Stack-based backend context manager implementation * code styele fix * Added get_default_backend() function which returns top stack backend. Stack returns config.default_backend if there is nothing in stack. A little clean-up in test file. * - Moved `set_default_backend` to the `backend_contextmanager` - `default_backend` now is a property of `_DefaultBackendStack` - removed `config` imports as an unused file. - fixed some tests in `backend_contextmanager_test.py` * little code-style fix Co-authored-by: Chase Roberts <chaseriley@google.com> * Algebraic operation add( + ), sub( - ), mul( * ), div( / ) for BaseNode class (#439) * BaseNode / Edge class text input protection added (#423) BaseNode class - Add protection to name, axis_names *Protected in 3 place *Initialize stage - __init__ *Function use setting - set_name / add_axis_names *Property - Add @Property to name to protect direct adding node.name = 123 Edge class - Add protection to name *Protected in 3 place *Initialize stage - __init__ *Function use setting - set_name *Property * BaseNode / Edge class text input protection code revise (#423) *if type(name) != str *if not isinstance(name, str) *change using type to isinstance to follow pylint * Algebraic operation add( + ), sub( - ), mul( * ), div( / ) for BaseNode class (#292) *[BaseNode class] - add / sub / mul / truediv NotImplemented function Added *[Node class] - add / sub / mul / truediv function added *[CopyNode class] - overload the BaseNode mul / truediv as NotImplemented *[basebackend] - add / sub / mul / div NotImplemented function added *[numpy / tensorflow / pytorch] - add / sub / mul / div function added *[shell] - add / sub / div NotImplemented function added *Testing files [network_components_free_test] * Exception - Tensorflow is not tested when the operand is scalar * 1. Check add / sub / mul / div with int / float / Node * 2. Check implicit conversion * 2. Check the Type Error when type is not int / float / Node * 3. Check is the operand backend same * 4. 
Check is BaseNode has attribute _tensor [backend_test - numpy / tensorflow / pytorch] *check add / sub / mul / divide work for int / float / Node * Add test cases for Tensorflow Algebraic operation and fix add, sub name (#292) [Change name] *add -> addition *subtract -> substraction [Add test case for Tensorflow] * Specify the datatype to resolve the conflict between different dtype operation [Test case for pytorch / jax] * pytorch - [int / int -> int] give different answer for torch when it is dividing two integer * jax - Different from other backend jax backend return 64bits dtype even operate between 32bits so put exceptional dtype test case for jax backend * Add test cases for Tensorflow Algebraic operation and fix add, sub name (#292) [Change name] *add -> addition *subtract -> substraction [Add test case for Tensorflow] * Specify the datatype to resolve the conflict between different dtype operation [Test case for pytorch / jax] * pytorch - [int / int -> int] give different answer for torch when it is dividing two integer * jax - Different from other backend jax backend return 64bits dtype even operate between 32bits so put exceptional dtype test case for jax backend * Add __add__, __sub__, __mul__, __truediv__ to TestNode Class Co-authored-by: Chase Roberts <chaseriley@google.com> * improved performance, but only u1 currently supported in this commit * Fix unsafe None checks (#449) * None checks added for constructors * Changes in None check and resolve comments * Backend test (#448) * added test for mps switch backend * added switch backend method to MPS * added test for network operations switch backend * make sure switch_backend not only fixes tensor but also node property * added switch_backend to init * missing test for backend contextmanager * notimplemented tests for base backend * added subtraction test notimplemented * added jax backend index_update test * first missing tests for numpy * actually catched an error in numpy_backend eigs method! 
* more eigs tests * didnt catch an error, unexpected convention * more tests for eigsh_lancszos * added missing pytorch backend tests * added missing tf backend tests * pytype * suppress pytype Co-authored-by: Chase Roberts <chaseriley@google.com> * Version bump for release * merging Glen's and my code * fix bug in unique * adding/removing tests * add benchmark file * adding files * deleted some files * fix bug * renaming and shortening * cleaning up code * cleaning up * removed _check_flows * remove a print * fix bug * nothing * fix big in flatten_meta_data * added bunch of tests * added inner and outer product * removed binary tree, switched to a list of charges * added copy() * bugfix * added __add__ __sub__ __mul__ __rmul__ * added __eq__ * change dosctring * added svd * better check * add proper compute_uv flag , remove artifact return values * added qr, eigh, eig * added tests * fix tests * fix index bugs Co-authored-by: Cutter Coryell <14116109+coryell@users.noreply.github.com> Co-authored-by: Chase Roberts <chaseriley@google.com> Co-authored-by: Ashley Milsted <ashmilsted@gmail.com> Co-authored-by: Ivan PANICO <iv.panico@gmail.com> Co-authored-by: Ori Alberton <github@oalberton.com> Co-authored-by: Kshithij Iyer <kshithij.ki@gmail.com> Co-authored-by: Hyunbyung, Park <hyunbyung87@gmail.com> Co-authored-by: MichaelMarien <marien.mich@gmail.com> Co-authored-by: kosehy <kosehy@gmail.com> Co-authored-by: Olga Okrut <46659064+olgOk@users.noreply.github.com> Co-authored-by: Aidan Dang <dang@aidan.gg> Co-authored-by: Tigran Katolikyan <43802339+katolikyan@users.noreply.github.com> Co-authored-by: Jayanth Chandra <jayanthchandra14@gmail.com>
1 parent 9a3959b commit fb27805
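
For orientation, a minimal usage sketch of the API this commit adds. It assumes `U1Charge` is importable from `tensornetwork.block_tensor.charge` and that the module layout matches the file tree below; names, charges and dimensions are illustrative only, not part of the commit.

```
import numpy as np
from tensornetwork.block_tensor.charge import U1Charge
from tensornetwork.block_tensor.index import Index
from tensornetwork.block_tensor.block_tensor import (BlockSparseTensor, svd,
                                                     qr, eig, eigh)

# a random U(1)-symmetric matrix (rank-2 BlockSparseTensor)
row = Index(charges=U1Charge(np.random.randint(-2, 3, 10)), flow=False)
col = Index(charges=U1Charge(np.random.randint(-2, 3, 10)), flow=True)
matrix = BlockSparseTensor.randn(indices=[row, col])

# the new decompositions act block by block on the charge-diagonal blocks
U, S, V = svd(matrix)
Q, R = qr(matrix)
```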

File tree

4 files changed: 311 additions & 137 deletions


tensornetwork/block_tensor/block_tensor.py

Lines changed: 245 additions & 17 deletions
@@ -17,17 +17,28 @@
 from __future__ import print_function
 import numpy as np
 from tensornetwork.backends import backend_factory
-from tensornetwork.block_tensor.index import Index, fuse_index_pair, split_index
+from tensornetwork.block_tensor.index import Index, fuse_index_pair
 from tensornetwork.block_tensor.charge import fuse_degeneracies, fuse_charges, fuse_degeneracies, BaseCharge, fuse_ndarray_charges, intersect
 import numpy as np
 import scipy as sp
 import itertools
 import time
-from typing import List, Union, Any, Tuple, Type, Optional, Dict, Iterable, Sequence
+from typing import List, Union, Any, Tuple, Type, Optional, Dict, Iterable, Sequence, Text
 Tensor = Any
 
 
-def get_flat_order(indices, order):
+def get_flat_order(indices: List[Index],
+                   order: Union[List[int], np.ndarray]) -> np.ndarray:
+  """
+  Compute the flat order of the
+  flattened `indices` corresponding to `order`.
+  Args:
+    indices: A list of `Index` objects.
+    order: A permutation of `indices`.
+  Returns:
+    The flat order of the flat indices corresponding
+    to the `order` of `indices`.
+  """
   flat_charges, _ = get_flat_meta_data(indices)
   flat_labels = np.arange(len(flat_charges))
   cum_num_legs = np.append(0, np.cumsum([len(i.flat_charges) for i in indices]))
@@ -38,6 +49,12 @@ def get_flat_order(indices, order):
 
 
 def get_flat_meta_data(indices):
+  """
+  Return charges and flows of flattened `indices`.
+  Args:
+    indices: A list of `Index` objects.
+  Returns:
+    (list, list): The flattened charges and the flattened flows.
+  """
   charges = []
   flows = []
   for i in indices:
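
As a hedged illustration of this helper's contract (the `Index` objects `i1`, `i2` are hypothetical, as in the `reshape` docstring further down):

```
# charges and flows come back as flat Python lists, one entry per elementary
# (flat) leg of each Index, concatenated in the order the indices are given
flat_charges, flat_flows = get_flat_meta_data([i1, i2])
assert len(flat_charges) == len(flat_flows)
```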
@@ -818,6 +835,29 @@ def flat_flows(self):
       flat.extend(i.flat_flows)
     return flat
 
+  def __matmul__(self, other):
+
+    if self.rank != 2:
+      raise ValueError('__matmul__ only implemented for matrices')
+
+    if other.rank != 2:
+      raise ValueError('__matmul__ only implemented for matrices')
+    return tensordot(self, other, ([1], [0]))
+
+  def conj(self):
+    """
+    Return the complex conjugate of the tensor.
+    The data is complex conjugated and all index flows are reversed.
+    Returns:
+      BlockSparseTensor: The conjugated tensor.
+    """
+    indices = [
+        Index(i.flat_charges, list(np.logical_not(i.flat_flows)), i.name)
+        for i in self.indices
+    ]
+    return BlockSparseTensor(np.conj(self.data), indices)
+
   def transpose(
       self,
       order: Union[List[int], np.ndarray],
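
A short, hedged sketch of the two methods added in this hunk; `A` and `B` stand for rank-2 `BlockSparseTensor`s with compatible contraction indices, and the names are illustrative only:

```
# __matmul__ is plain matrix multiplication via tensordot:
# axis 1 of A is contracted with axis 0 of B
C = A @ B

# conj() conjugates the data and reverses every index flow;
# the shape is unchanged
Aconj = A.conj()
assert Aconj.shape == A.shape
```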
@@ -839,7 +879,6 @@ def transpose(
       return BlockSparseTensor(self.data, self.indices)
     flat_charges, flat_flows = get_flat_meta_data(self.indices)
     flat_order = get_flat_order(self.indices, order)
-    print(flat_order)
     tr_partition = _find_best_partition(
         [len(flat_charges[n]) for n in flat_order])
 
@@ -938,8 +977,9 @@ def reshape(tensor: BlockSparseTensor,
   Reshape `tensor` into `shape`.
   `reshape` works essentially the same as the dense version, with the
   notable exception that the tensor can only be reshaped into a form
-  compatible with its elementary indices. The elementary indices are
-  the indices at the leaves of the `Index` objects `tensors.indices`.
+  compatible with its elementary shape. The elementary shape is
+  the shape determined by the flattened charges of all `Index` objects
+  in `tensor.indices`.
   For example, while the following reshaping is possible for regular
   dense numpy tensor,
   ```
@@ -948,14 +988,14 @@ def reshape(tensor: BlockSparseTensor,
   ```
   the same code for BlockSparseTensor
   ```
-  q1 = np.random.randint(0,10,6)
-  q2 = np.random.randint(0,10,6)
-  q3 = np.random.randint(0,10,6)
-  i1 = Index(charges=q1,flow=1)
-  i2 = Index(charges=q2,flow=-1)
-  i3 = Index(charges=q3,flow=1)
+  q1 = U1Charge(np.random.randint(0,10,6))
+  q2 = U1Charge(np.random.randint(0,10,6))
+  q3 = U1Charge(np.random.randint(0,10,6))
+  i1 = Index(charges=q1,flow=False)
+  i2 = Index(charges=q2,flow=True)
+  i3 = Index(charges=q3,flow=False)
   A=BlockSparseTensor.randn(indices=[i1,i2,i3])
-  print(nA.shape) #prints (6,6,6)
+  print(A.shape) #prints (6,6,6)
   reshape(A, (2,3,6,6)) #raises ValueError
   ```
   raises a `ValueError` since (2,3,6,6)
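
To complement the failing example in the docstring above, a hedged sketch of reshapes that are compatible with the elementary shape (reusing the `A` from the docstring; fusing or splitting whole elementary legs is the supported case):

```
B = reshape(A, (36, 6))    # fuses the first two legs of the (6, 6, 6) tensor
C = reshape(B, (6, 6, 6))  # splitting back into the elementary shape also works
```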
@@ -975,8 +1015,8 @@ def reshape(tensor: BlockSparseTensor,
 def transpose(tensor: BlockSparseTensor,
               order: Union[List[int], np.ndarray]) -> "BlockSparseTensor":
   """
-  Transpose `tensor` into the new order `order`. This routine currently shuffles
-  data.
+  Transpose `tensor` into the new order `order`.
+  This routine currently shuffles data.
   Args:
     tensor: The tensor to be transposed.
     order: The new order of indices.
@@ -1207,7 +1247,8 @@ def tensordot(
 def svd(matrix: BlockSparseTensor,
         full_matrices: Optional[bool] = True,
         compute_uv: Optional[bool] = True,
-        hermitian: Optional[bool] = False):
+        hermitian: Optional[bool] = False
+       ) -> Tuple[BlockSparseTensor, BlockSparseTensor, BlockSparseTensor]:
   """
   Compute the singular value decomposition of `matrix`.
   The matrix is factorized into `u * s * vh`, with
@@ -1220,10 +1261,14 @@ def svd(matrix: BlockSparseTensor,
       and `v.shape[0]=s.shape[1]`
     compute_uv: If `True`, return `u` and `v`.
     hermitian: If `True`, assume hermiticity of `matrix`.
+  Returns:
+    If `compute_uv` is `True`: Three BlockSparseTensors `U,S,V`.
+    If `compute_uv` is `False`: A BlockSparseTensor `S` containing the
+      singular values.
   """
 
   if matrix.rank != 2:
-    raise NotImplementedError("SVD currently supports only rank-2 tensors.")
+    raise NotImplementedError("svd currently supports only rank-2 tensors.")
 
   flat_charges = matrix.indices[0]._charges + matrix.indices[1]._charges
   flat_flows = matrix.flat_flows
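
A hedged usage note for the two return conventions documented above, assuming a rank-2 `BlockSparseTensor` named `matrix` as in the introductory sketch:

```
U, S, V = svd(matrix, compute_uv=True)   # three BlockSparseTensors
S_only = svd(matrix, compute_uv=False)   # only the singular values
# hermitian=True asserts hermiticity of `matrix` (per the docstring)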
@@ -1296,3 +1341,186 @@ def svd(matrix: BlockSparseTensor,
       np.concatenate([np.ravel(v) for v in v_blocks]), indices_v)
 
   return S
+
+
+def qr(matrix: BlockSparseTensor, mode: Optional[Text] = 'reduced'
+      ) -> Tuple[BlockSparseTensor, BlockSparseTensor]:
+  """
+  Compute the qr decomposition of an `M` by `N` matrix `matrix`.
+  The matrix is factorized into `q*r`, with
+  `q` an orthogonal matrix and `r` an upper triangular matrix.
+  Args:
+    matrix: A matrix (i.e. a rank-2 tensor) of type `BlockSparseTensor`
+    mode: Can take values {'reduced', 'complete', 'r', 'raw'}.
+      If K = min(M, N), then
+
+      * 'reduced'  : returns q, r with dimensions (M, K), (K, N) (default)
+      * 'complete' : returns q, r with dimensions (M, M), (M, N)
+      * 'r'        : returns r only with dimensions (K, N)
+
+  Returns:
+    (BlockSparseTensor, BlockSparseTensor): If mode = `reduced` or `complete`.
+    BlockSparseTensor: If mode = `r`.
+  """
+  if mode == 'raw':
+    raise NotImplementedError('mode `raw` currently not supported')
+  if matrix.rank != 2:
+    raise NotImplementedError("qr currently supports only rank-2 tensors.")
+
+  flat_charges = matrix.indices[0]._charges + matrix.indices[1]._charges
+  flat_flows = matrix.flat_flows
+  partition = len(matrix.indices[0].flat_charges)
+  blocks, charges, shapes = _find_diagonal_sparse_blocks(
+      flat_charges, flat_flows, partition)
+
+  q_blocks = []
+  r_blocks = []
+  for n in range(len(blocks)):
+    out = np.linalg.qr(np.reshape(matrix.data[blocks[n]], shapes[:, n]), mode)
+    if mode in ('reduced', 'complete'):
+      q_blocks.append(out[0])
+      r_blocks.append(out[1])
+    elif mode == 'r':
+      r_blocks.append(out)
+    else:
+      raise ValueError('unknown value {} for input `mode`'.format(mode))
+
+  left_r_charge = charges.__new__(type(charges))
+  left_r_charge_labels = np.concatenate([
+      np.full(r_blocks[n].shape[0], fill_value=n, dtype=np.int16)
+      for n in range(len(r_blocks))
+  ])
+
+  left_r_charge.__init__(charges.unique_charges, left_r_charge_labels,
+                         charges.charge_types)
+  indices_r = [Index(left_r_charge, False), matrix.indices[1]]
+
+  R = BlockSparseTensor(
+      np.concatenate([np.ravel(r) for r in r_blocks]), indices_r)
+  if mode in ('reduced', 'complete'):
+    right_q_charge = charges.__new__(type(charges))
+    right_q_charge_labels = np.concatenate([
+        np.full(q_blocks[n].shape[1], fill_value=n, dtype=np.int16)
+        for n in range(len(q_blocks))
+    ])
+    right_q_charge.__init__(charges.unique_charges, right_q_charge_labels,
+                            charges.charge_types)
+
+    indices_q = [Index(right_q_charge, True), matrix.indices[0]]
+    #TODO: reuse data from _find_diagonal_sparse_blocks above
+    #to avoid the transpose
+    return BlockSparseTensor(
+        np.concatenate([np.ravel(q.T) for q in q_blocks]),
+        indices_q).transpose((1, 0)), R
+
+  return R
+
+
+def eigh(matrix: BlockSparseTensor,
+         UPLO: Optional[Text] = 'L'
+        ) -> Tuple[BlockSparseTensor, BlockSparseTensor]:
+  """
+  Compute the eigen decomposition of a hermitian `M` by `M` matrix `matrix`.
+  Args:
+    matrix: A matrix (i.e. a rank-2 tensor) of type `BlockSparseTensor`
+    UPLO: Whether the lower ('L') or upper ('U') triangle of `matrix`
+      is used in the block-wise `np.linalg.eigh` calls.
+
+  Returns:
+    (BlockSparseTensor, BlockSparseTensor): The eigenvalues and eigenvectors
+
+  """
+  if matrix.rank != 2:
+    raise NotImplementedError("eigh currently supports only rank-2 tensors.")
+
+  flat_charges = matrix.indices[0]._charges + matrix.indices[1]._charges
+  flat_flows = matrix.flat_flows
+  partition = len(matrix.indices[0].flat_charges)
+  blocks, charges, shapes = _find_diagonal_sparse_blocks(
+      flat_charges, flat_flows, partition)
+
+  eigvals = []
+  v_blocks = []
+  for n in range(len(blocks)):
+    e, v = np.linalg.eigh(
+        np.reshape(matrix.data[blocks[n]], shapes[:, n]), UPLO)
+    eigvals.append(np.diag(e))
+    v_blocks.append(v)
+
+  left_v_charge = charges.__new__(type(charges))
+  left_v_charge_labels = np.concatenate([
+      np.full(v_blocks[n].shape[0], fill_value=n, dtype=np.int16)
+      for n in range(len(v_blocks))
+  ])
+
+  left_v_charge.__init__(charges.unique_charges, left_v_charge_labels,
+                         charges.charge_types)
+  indices_v = [Index(left_v_charge, False), matrix.indices[1]]
+
+  V = BlockSparseTensor(
+      np.concatenate([np.ravel(v) for v in v_blocks]), indices_v)
+  eigvalscharge = charges.__new__(type(charges))
+  eigvalscharge_labels = np.concatenate([
+      np.full(eigvals[n].shape[1], fill_value=n, dtype=np.int16)
+      for n in range(len(eigvals))
+  ])
+  eigvalscharge.__init__(charges.unique_charges, eigvalscharge_labels,
+                         charges.charge_types)
+
+  indices_q = [Index(eigvalscharge, True), matrix.indices[0]]
+  #TODO: reuse data from _find_diagonal_sparse_blocks above
+  #to avoid the transpose
+  return BlockSparseTensor(
+      np.concatenate([np.ravel(q.T) for q in eigvals]),
+      indices_q).transpose((1, 0)), V
+
+
+def eig(matrix: BlockSparseTensor
+       ) -> Tuple[BlockSparseTensor, BlockSparseTensor]:
+  """
+  Compute the eigen decomposition of an `M` by `M` matrix `matrix`.
+  Args:
+    matrix: A matrix (i.e. a rank-2 tensor) of type `BlockSparseTensor`
+
+  Returns:
+    (BlockSparseTensor, BlockSparseTensor): The eigenvalues and eigenvectors
+
+  """
+  if matrix.rank != 2:
+    raise NotImplementedError("eig currently supports only rank-2 tensors.")
+
+  flat_charges = matrix.indices[0]._charges + matrix.indices[1]._charges
+  flat_flows = matrix.flat_flows
+  partition = len(matrix.indices[0].flat_charges)
+  blocks, charges, shapes = _find_diagonal_sparse_blocks(
+      flat_charges, flat_flows, partition)
+
+  eigvals = []
+  v_blocks = []
+  for n in range(len(blocks)):
+    e, v = np.linalg.eig(np.reshape(matrix.data[blocks[n]], shapes[:, n]))
+    eigvals.append(np.diag(e))
+    v_blocks.append(v)
+
+  left_v_charge = charges.__new__(type(charges))
+  left_v_charge_labels = np.concatenate([
+      np.full(v_blocks[n].shape[0], fill_value=n, dtype=np.int16)
+      for n in range(len(v_blocks))
+  ])
+
+  left_v_charge.__init__(charges.unique_charges, left_v_charge_labels,
+                         charges.charge_types)
+  indices_v = [Index(left_v_charge, False), matrix.indices[1]]
+
+  V = BlockSparseTensor(
+      np.concatenate([np.ravel(v) for v in v_blocks]), indices_v)
+  eigvalscharge = charges.__new__(type(charges))
+  eigvalscharge_labels = np.concatenate([
+      np.full(eigvals[n].shape[1], fill_value=n, dtype=np.int16)
+      for n in range(len(eigvals))
+  ])
+  eigvalscharge.__init__(charges.unique_charges, eigvalscharge_labels,
+                         charges.charge_types)
+
+  indices_q = [Index(eigvalscharge, True), matrix.indices[0]]
+  #TODO: reuse data from _find_diagonal_sparse_blocks above
+  #to avoid the transpose
+  return BlockSparseTensor(
+      np.concatenate([np.ravel(q.T) for q in eigvals]),
+      indices_q).transpose((1, 0)), V
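
Finally, a hedged sketch of how the three factorizations added in this hunk might be exercised, again using the `matrix` from the introductory sketch; the hermitian product `H` is constructed only so that `eigh` receives a valid square, hermitian input, and whether its charges line up depends on the flows chosen when `matrix` was built:

```
Q, R = qr(matrix, mode='reduced')   # default; 'complete' is also supported
R_only = qr(matrix, mode='r')       # returns only the upper-triangular factor

H = matrix @ matrix.conj().transpose((1, 0))  # hermitian by construction
E, W = eigh(H)    # eigenvalues (block tensor) and eigenvectors
E2, W2 = eig(H)   # general (non-hermitian) eigensolver
```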
