Skip to content

Commit

Permalink
contraction_list: only keep remaining for last 20 contractions (#149)
Browse files Browse the repository at this point in the history
* contraction_list: only keep remaining for last 10 contractions

* match tab spacing

* catch ValueError for numpy 'order' change

* change to store and show last 20 'remaining terms' in PathInfo
  • Loading branch information
jcmgray authored Jul 19, 2020
1 parent 2cbd28e commit cb72a33
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 17 deletions.
18 changes: 9 additions & 9 deletions opt_einsum/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,18 @@


def _get_jax_and_to_jax():
global _JAX
if _JAX is None:
import jax
global _JAX
if _JAX is None:
import jax

@to_backend_cache_wrap
@jax.jit
def to_jax(x):
return x
@to_backend_cache_wrap
@jax.jit
def to_jax(x):
return x

_JAX = jax, to_jax
_JAX = jax, to_jax

return _JAX
return _JAX


def build_expression(_, expr): # pragma: no cover
Expand Down
25 changes: 19 additions & 6 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,22 @@ def __repr__(self):
" Complete contraction: {}\n".format(self.eq), " Naive scaling: {}\n".format(len(self.indices)),
" Optimized scaling: {}\n".format(max(self.scale_list)), " Naive FLOP count: {:.3e}\n".format(
self.naive_cost), " Optimized FLOP count: {:.3e}\n".format(self.opt_cost),
" Theoretical speedup: {:3.3f}\n".format(self.speedup),
" Theoretical speedup: {:.3e}\n".format(self.speedup),
" Largest intermediate: {:.3e} elements\n".format(self.largest_intermediate), "-" * 80 + "\n",
"{:>6} {:>11} {:>22} {:>37}\n".format(*header), "-" * 80
]

for n, contraction in enumerate(self.contraction_list):
inds, idx_rm, einsum_str, remaining, do_blas = contraction
remaining_str = ",".join(remaining) + "->" + self.output_subscript
path_run = (self.scale_list[n], do_blas, einsum_str, remaining_str)
path_print.append("\n{:>4} {:>14} {:>22} {:>37}".format(*path_run))

if remaining is not None:
remaining_str = ",".join(remaining) + "->" + self.output_subscript
else:
remaining_str = "..."
size_remaining = max(0, 56 - max(22, len(einsum_str)))

path_run = (self.scale_list[n], do_blas, einsum_str, remaining_str, size_remaining)
path_print.append("\n{:>4} {:>14} {:>22} {:>{}}".format(*path_run))

return "".join(path_print)

Expand Down Expand Up @@ -303,7 +309,14 @@ def contract_path(*operands, **kwargs):

einsum_str = ",".join(tmp_inputs) + "->" + idx_result

contraction = (contract_inds, idx_removed, einsum_str, input_list[:], do_blas)
# for large expressions saving the remaining terms at each step can
# incur a large memory footprint - and also be messy to print
if len(input_list) <= 20:
remaining = tuple(input_list)
else:
remaining = None

contraction = (contract_inds, idx_removed, einsum_str, remaining, do_blas)
contraction_list.append(contraction)

opt_cost = sum(cost_list)
Expand Down Expand Up @@ -529,7 +542,7 @@ def _core_contract(operands, contraction_list, backend='auto', evaluate_constant

# Start contraction loop
for num, contraction in enumerate(contraction_list):
inds, idx_rm, einsum_str, remaining, blas_flag = contraction
inds, idx_rm, einsum_str, _, blas_flag = contraction

# check if we are performing the pre-pass of an expression with constants,
# if so, break out upon finding first non-constant (None) operand
Expand Down
2 changes: 1 addition & 1 deletion opt_einsum/tests/test_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_printing():
views = helpers.build_views(string)

ein = contract_path(string, *views)
assert len(str(ein[1])) == 726
assert len(str(ein[1])) == 728


@pytest.mark.parametrize("string", tests)
Expand Down
3 changes: 2 additions & 1 deletion opt_einsum/tests/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def test_type_errors():
contract("", 0, out='test')

# order parameter must be a valid order
with pytest.raises(TypeError):
# changed in Numpy 1.19, see https://github.com/numpy/numpy/commit/35b0a051c19265f5643f6011ee11e31d30c8bc4c
with pytest.raises((TypeError, ValueError)):
contract("", 0, order='W')

# casting parameter must be a valid casting
Expand Down

0 comments on commit cb72a33

Please sign in to comment.