Skip to content

Commit

Permalink
various changes to path machinery (#61)
Browse files Browse the repository at this point in the history
* path finding changes

* fix cupy a bit

* PathInfo: inherit from object and use join for repr

* update docstrings and var names

* update string formatting

* test rand_equation

* set seed for rand_equation test
  • Loading branch information
jcmgray authored Oct 2, 2018
1 parent 1db5a0a commit ccbdf6b
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 82 deletions.
6 changes: 3 additions & 3 deletions docs/source/greedy_path.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ It should be stressed these cases are quite rare and by default ``contract`` use
>>> B = np.random.rand(37, 51, 51, 59)
>>> C = np.random.rand(59, 27)
>>> path, desc = oe.contract_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, path="greedy")
>>> path, desc = oe.contract_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="greedy")
>>> print(desc)
Complete contraction: xyf,xtf,ytpf,fr->tpr
Naive scaling: 6
Expand All @@ -41,7 +41,7 @@ It should be stressed these cases are quite rare and by default ``contract`` use
4 False tpfx,xtf->tpf fr,tpf->tpr
4 GEMM tpf,fr->tpr tpr->tpr
>>> path, desc = oe.contract_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, path="optimal")
>>> path, desc = oe.contract_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="optimal")
>>> print(desc)
Complete contraction: xyf,xtf,ytpf,fr->tpr
Expand All @@ -56,4 +56,4 @@ It should be stressed these cases are quite rare and by default ``contract`` use
--------------------------------------------------------------------------------
4 False xtf,xyf->tfy ytpf,fr,tfy->tpr
4 False tfy,ytpf->tfp fr,tfp->tpr
4 TDOT tfp,fr->tpr tpr->tpr
4 TDOT tfp,fr->tpr tpr->tpr
161 changes: 103 additions & 58 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,70 @@
# Names exported by a star-import of this module; each name is listed exactly once.
__all__ = ["contract_path", "contract", "format_const_einsum_str", "ContractExpression", "shape_only"]


class PathInfo(object):
    """Container for the description of a contraction path.

    The constructor only stores its arguments; the human-readable table is
    built lazily, each time ``repr()`` is requested.

    Parameters
    ----------
    contraction_list : sequence of 5-item sequences
        One entry per contraction step. Only the last three items of each
        entry are used here: the einsum string of the step, an iterable of
        the remaining terms (strings) and the BLAS marker for the step.
    input_subscripts : str
        Input part of the full contraction expression.
    output_subscript : str
        Output part of the full contraction expression.
    indices : sized
        Its length is reported as the naive scaling.
    scale_list : sequence
        Scaling of each step, indexed in parallel with ``contraction_list``.
    naive_cost : number
        Reported as the naive FLOP count.
    opt_cost : number
        Reported as the optimized FLOP count.
    size_list : iterable
        Only its maximum is kept, as ``largest_intermediate``.
    """

    def __init__(self, contraction_list, input_subscripts, output_subscript,
                 indices, scale_list, naive_cost, opt_cost, size_list):
        self.contraction_list = contraction_list
        self.input_subscripts = input_subscripts
        self.output_subscript = output_subscript
        self.indices = indices
        self.scale_list = scale_list
        self.naive_cost = naive_cost
        self.opt_cost = opt_cost
        self.largest_intermediate = max(size_list)

    def __repr__(self):
        """Return the multi-line summary table describing the contraction."""
        from decimal import Decimal

        # naive costs / speedups can easily reach >~ 1e308, need exact arithmetic
        naive_cost = Decimal(self.naive_cost)
        opt_cost = Decimal(self.opt_cost)
        largest_intermediate = Decimal(self.largest_intermediate)

        # Decimal division by zero raises; a repr must never fail, so a zero
        # optimized cost falls back to a divisor of one.
        speedup = naive_cost / (opt_cost if opt_cost else Decimal(1))

        overall_contraction = self.input_subscripts + "->" + self.output_subscript
        header = ("scaling", "BLAS", "current", "remaining")

        path_print = [
            " Complete contraction: {}\n".format(overall_contraction),
            " Naive scaling: {}\n".format(len(self.indices)),
            " Optimized scaling: {}\n".format(max(self.scale_list)),
            " Naive FLOP count: {:.3e}\n".format(naive_cost),
            " Optimized FLOP count: {:.3e}\n".format(opt_cost),
            " Theoretical speedup: {:3.3f}\n".format(speedup),
            " Largest intermediate: {:.3e} elements\n".format(largest_intermediate),
            "-" * 80 + "\n",
            "{:>6} {:>11} {:>22} {:>37}\n".format(*header),
            "-" * 80
        ]

        for n, contraction in enumerate(self.contraction_list):
            _, _, einsum_str, remaining, do_blas = contraction
            remaining_str = ",".join(remaining) + "->" + self.output_subscript
            path_run = (self.scale_list[n], do_blas, einsum_str, remaining_str)
            # ``!s`` forces str() on the BLAS marker: a bare ``{:>14}`` would
            # format the boolean ``False`` through int.__format__ and print "0".
            path_print.append("\n{:>4} {!s:>14} {:>22} {:>37}".format(*path_run))

        return "".join(path_print)


def _choose_memory_arg(memory_limit, size_list):
if memory_limit == 'max_input':
return max(size_list)

if memory_limit is None:
return int(1e20)

if memory_limit < 1:
if memory_limit == -1:
return int(1e20)
else:
raise ValueError("Memory limit must be larger than 0, or -1")

return int(memory_limit)


def contract_path(*operands, **kwargs):
"""
Evaluates the lowest cost einsum-like contraction order.
Expand All @@ -25,7 +89,7 @@ def contract_path(*operands, **kwargs):
Specifies the subscripts for summation.
*operands : list of array_like
These are the arrays for the operation.
path : bool or list, optional (default: ``auto``)
optimize : bool or list, optional (default: ``auto``)
Choose the type of path.
- if a list is given uses this as the path.
Expand Down Expand Up @@ -111,12 +175,18 @@ def contract_path(*operands, **kwargs):
"""

# Make sure all keywords are valid
valid_contract_kwargs = ['path', 'memory_limit', 'einsum_call', 'use_blas']
valid_contract_kwargs = ['optimize', 'path', 'memory_limit', 'einsum_call', 'use_blas']
unknown_kwargs = [k for (k, v) in kwargs.items() if k not in valid_contract_kwargs]
if len(unknown_kwargs):
raise TypeError("einsum_path: Did not understand the following kwargs: %s" % unknown_kwargs)
raise TypeError("einsum_path: Did not understand the following kwargs: {}".format(unknown_kwargs))

if 'path' in kwargs:
import warnings
warnings.warn("The 'path' keyword argument is deprecated in favor of 'optimize'.", DeprecationWarning)
path_type = kwargs.pop('path')
else:
path_type = kwargs.pop('optimize', 'auto')

path_type = kwargs.pop('path', 'auto')
memory_limit = kwargs.pop('memory_limit', None)

# Hidden option, only einsum should call this
Expand All @@ -139,8 +209,8 @@ def contract_path(*operands, **kwargs):
sh = input_shps[tnum]

if len(sh) != len(term):
raise ValueError("Einstein sum subscript %s does not contain the "
"correct number of indices for operand %d." % (input_subscripts[tnum], tnum))
raise ValueError("Einstein sum subscript '{}' does not contain the "
"correct number of indices for operand {}.".format(input_subscripts[tnum], tnum))
for cnum, char in enumerate(term):
dim = sh[cnum]

Expand All @@ -149,25 +219,14 @@ def contract_path(*operands, **kwargs):
if dimension_dict[char] == 1:
dimension_dict[char] = dim
elif dim not in (1, dimension_dict[char]):
raise ValueError("Size of label '%s' for operand %d (%d) "
"does not match previous terms (%d)." % (char, tnum, dimension_dict[char], dim))
raise ValueError("Size of label '{}' for operand {} ({}) does not match previous "
"terms ({}).".format(char, tnum, dimension_dict[char], dim))
else:
dimension_dict[char] = dim

# Compute size of each input array plus the output array
size_list = [helpers.compute_size_by_dict(term, dimension_dict) for term in input_list + [output_subscript]]
out_size = max(size_list)

if memory_limit is None:
memory_arg = out_size
else:
if memory_limit < 1:
if memory_limit == -1:
memory_arg = int(1e20)
else:
raise ValueError("Memory limit must be larger than 0, or -1")
else:
memory_arg = int(memory_limit)
memory_arg = _choose_memory_arg(memory_limit, size_list)

num_ops = len(input_list)

Expand Down Expand Up @@ -195,7 +254,7 @@ def contract_path(*operands, **kwargs):
elif path_type in ("greedy", "opportunistic", "auto"):
path = paths.greedy(input_sets, output_set, dimension_dict, memory_arg)
else:
raise KeyError("Path name %s not found" % path_type)
raise KeyError("Path name '{}' not found".format(path_type))

cost_list = []
scale_list = []
Expand Down Expand Up @@ -247,26 +306,8 @@ def contract_path(*operands, **kwargs):
if einsum_call_arg:
return operands, contraction_list

# Return the path along with a nice string representation
overall_contraction = input_subscripts + "->" + output_subscript
header = ("scaling", "BLAS", "current", "remaining")

path_print = " Complete contraction: %s\n" % overall_contraction
path_print += " Naive scaling: %d\n" % len(indices)
path_print += " Optimized scaling: %d\n" % max(scale_list)
path_print += " Naive FLOP count: %.3e\n" % naive_cost
path_print += " Optimized FLOP count: %.3e\n" % opt_cost
path_print += " Theoretical speedup: %3.3f\n" % (naive_cost / float(opt_cost))
path_print += " Largest intermediate: %.3e elements\n" % max(size_list)
path_print += "-" * 80 + "\n"
path_print += "%6s %11s %22s %37s\n" % header
path_print += "-" * 80

for n, contraction in enumerate(contraction_list):
inds, idx_rm, einsum_str, remaining, do_blas = contraction
remaining_str = ",".join(remaining) + "->" + output_subscript
path_run = (scale_list[n], do_blas, einsum_str, remaining_str)
path_print += "\n%4d %14s %22s %37s" % path_run
path_print = PathInfo(contraction_list, input_subscripts, output_subscript,
indices, scale_list, naive_cost, opt_cost, size_list)

return path, path_print

Expand Down Expand Up @@ -350,11 +391,15 @@ def contract(*operands, **kwargs):
contracting the listed tensors. Scales exponentially with
the number of terms in the contraction.
memory_limit : int or None (default : None)
The upper limit of the size of tensor created, by default, this will be
memory_limit : {None, int, 'max_input'} (default: None)
Give the upper bound of the largest intermediate tensor contract will build.
By default (None) will size the ``memory_limit`` as the largest input tensor.
Users can also specify ``-1`` to allow arbitrarily large tensors to be built.
- None or -1 means there is no limit
- 'max_input' means the limit is set as largest input tensor
- a positive integer is taken as an explicit limit on the number of elements
The default is None. Note that imposing a limit can make contractions
exponentially slower to perform.
backend : str, optional (default: ``numpy``)
Which library to use to perform the required ``tensordot``, ``transpose``
and ``einsum`` calls. Should match the types of arrays supplied, See
Expand Down Expand Up @@ -407,14 +452,14 @@ def contract(*operands, **kwargs):
# Make sure remaining keywords are valid for einsum
unknown_kwargs = [k for (k, v) in kwargs.items() if k not in valid_einsum_kwargs]
if len(unknown_kwargs):
raise TypeError("Did not understand the following kwargs: %s" % unknown_kwargs)
raise TypeError("Did not understand the following kwargs: {}".format(unknown_kwargs))

if gen_expression:
full_str = operands[0]

# Build the contraction list and operand
operands, contraction_list = contract_path(
*operands, path=optimize_arg, memory_limit=memory_limit, einsum_call=True, use_blas=use_blas)
*operands, optimize=optimize_arg, memory_limit=memory_limit, einsum_call=True, use_blas=use_blas)

# check if performing contraction or just building expression
if gen_expression:
Expand Down Expand Up @@ -630,13 +675,13 @@ def __call__(self, *arrays, **kwargs):

if kwargs:
raise ValueError("The only valid keyword arguments to a `ContractExpression` "
"call are `out=` or `backend=`. Got: %s." % kwargs)
"call are `out=` or `backend=`. Got: {}.".format(kwargs))

correct_num_args = self._full_num_args if evaluate_constants else self.num_args

if len(arrays) != correct_num_args:
raise ValueError("This `ContractExpression` takes exactly %s array arguments "
"but received %s." % (self.num_args, len(arrays)))
raise ValueError("This `ContractExpression` takes exactly {} array arguments "
"but received {}.".format(self.num_args, len(arrays)))

if self._constants_dict and not evaluate_constants:
# fill in the missing non-constant terms with newly supplied arrays
Expand All @@ -657,7 +702,7 @@ def __call__(self, *arrays, **kwargs):
original_msg = str(err.args) if err.args else ""
msg = ("Internal error while evaluating `ContractExpression`. Note that few checks are performed"
" - the number and rank of the array arguments must match the original expression. "
"The internal error was: '%s'" % original_msg, )
"The internal error was: '{}'".format(original_msg), )
err.args = msg
raise

Expand All @@ -669,13 +714,13 @@ def __repr__(self):
return "<ContractExpression('{}'{})>".format(self.contraction, constants_repr)

def __str__(self):
s = self.__repr__()
s = [self.__repr__()]
for i, c in enumerate(self.contraction_list):
s += "\n %i. " % (i + 1)
s += "'%s'" % c[2] + (" [%s]" % c[-1] if c[-1] else "")
s.append("\n {}. ".format(i + 1))
s.append("'{}'".format(c[2]) + (" [{}]".format(c[-1]) if c[-1] else ""))
if self.einsum_kwargs:
s += "\neinsum_kwargs=%s" % self.einsum_kwargs
return s
s.append("\neinsum_kwargs={}".format(self.einsum_kwargs))
return "".join(s)


def shape_only(shape):
Expand Down Expand Up @@ -751,8 +796,8 @@ def contract_expression(subscripts, *shapes, **kwargs):

for arg in ('out', 'backend'):
if kwargs.get(arg, None) is not None:
raise ValueError("'%s' should only be specified when calling a "
"`ContractExpression`, not when building it." % arg)
raise ValueError("'{}' should only be specified when calling a "
"`ContractExpression`, not when building it.".format(arg))

kwargs['_gen_expression'] = True

Expand Down
Loading

0 comments on commit ccbdf6b

Please sign in to comment.