Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

various changes to path machinery #61

Merged
merged 6 commits into from
Oct 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/greedy_path.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ It should be stressed these cases are quite rare and by default ``contract`` use
>>> B = np.random.rand(37, 51, 51, 59)
>>> C = np.random.rand(59, 27)

>>> path, desc = oe.contract_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, path="greedy")
>>> path, desc = oe.contract_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="greedy")
>>> print(desc)
Complete contraction: xyf,xtf,ytpf,fr->tpr
Naive scaling: 6
Expand All @@ -41,7 +41,7 @@ It should be stressed these cases are quite rare and by default ``contract`` use
4 False tpfx,xtf->tpf fr,tpf->tpr
4 GEMM tpf,fr->tpr tpr->tpr

>>> path, desc = oe.contract_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, path="optimal")
>>> path, desc = oe.contract_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="optimal")
>>> print(desc)

Complete contraction: xyf,xtf,ytpf,fr->tpr
Expand All @@ -56,4 +56,4 @@ It should be stressed these cases are quite rare and by default ``contract`` use
--------------------------------------------------------------------------------
4 False xtf,xyf->tfy ytpf,fr,tfy->tpr
4 False tfy,ytpf->tfp fr,tfp->tpr
4 TDOT tfp,fr->tpr tpr->tpr
4 TDOT tfp,fr->tpr tpr->tpr
161 changes: 103 additions & 58 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,70 @@
# Public names exported by ``from <module> import *``; each name is listed once.
__all__ = ["contract_path", "contract", "format_const_einsum_str", "ContractExpression", "shape_only"]


class PathInfo(object):
    """Container for the details of a contraction path.

    The ``repr`` of an instance is a multi-line, human-readable report of the
    contraction: overall expression, scalings, FLOP counts, speedup, and one
    row per pairwise contraction step.

    Parameters
    ----------
    contraction_list : sequence
        One 5-item entry per step, unpacked as
        ``(inds, idx_rm, einsum_str, remaining, do_blas)``.
    input_subscripts : str
        Subscripts of the inputs (the part before ``->``).
    output_subscript : str
        Subscripts of the output (the part after ``->``).
    indices : sized collection
        Its length is reported as the naive scaling.
    scale_list : sequence
        Scaling of each step; its maximum is the optimized scaling.
    naive_cost, opt_cost : number
        FLOP counts of the naive and the optimized contraction.
    size_list : iterable of numbers
        Sizes whose maximum is stored as ``largest_intermediate``.
    """

    def __init__(self, contraction_list, input_subscripts, output_subscript,
                 indices, scale_list, naive_cost, opt_cost, size_list):
        self.contraction_list = contraction_list
        self.input_subscripts = input_subscripts
        self.output_subscript = output_subscript
        self.indices = indices
        self.scale_list = scale_list
        self.naive_cost = naive_cost
        self.opt_cost = opt_cost
        self.largest_intermediate = max(size_list)

    def __repr__(self):
        """Return the formatted, multi-line description of the path."""
        from decimal import Decimal

        # naive costs / speedups can easily reach >~ 1e308, need exact arithmetic
        naive_cost = Decimal(self.naive_cost)
        opt_cost = Decimal(self.opt_cost)
        speedup = naive_cost / opt_cost
        largest_intermediate = Decimal(self.largest_intermediate)

        # Return the path along with a nice string representation
        overall_contraction = self.input_subscripts + "->" + self.output_subscript
        header = ("scaling", "BLAS", "current", "remaining")

        # NOTE(review): a reviewer suggested migrating any remaining %-style
        # formatting elsewhere in the code base to str.format for consistency.
        path_print = [
            "  Complete contraction:  {}\n".format(overall_contraction),
            "         Naive scaling:  {}\n".format(len(self.indices)),
            "     Optimized scaling:  {}\n".format(max(self.scale_list)),
            "      Naive FLOP count:  {:.3e}\n".format(naive_cost),
            "  Optimized FLOP count:  {:.3e}\n".format(opt_cost),
            "   Theoretical speedup:  {:3.3f}\n".format(speedup),
            "  Largest intermediate:  {:.3e} elements\n".format(largest_intermediate),
            "-" * 80 + "\n",
            "{:>6} {:>11} {:>22} {:>37}\n".format(*header),
            "-" * 80,
        ]

        for n, contraction in enumerate(self.contraction_list):
            _, _, einsum_str, remaining, do_blas = contraction
            remaining_str = ",".join(remaining) + "->" + self.output_subscript
            path_run = (self.scale_list[n], do_blas, einsum_str, remaining_str)
            # ``do_blas`` may be the bool ``False``; a non-empty format spec on
            # a bool goes through int formatting and would print ``0``, so
            # force ``str()`` with ``!s`` to keep the ``False`` label.
            path_print.append("\n{:>4} {!s:>14} {:>22} {:>37}".format(*path_run))

        return "".join(path_print)


def _choose_memory_arg(memory_limit, size_list):
if memory_limit == 'max_input':
return max(size_list)

if memory_limit is None:
return int(1e20)

if memory_limit < 1:
if memory_limit == -1:
return int(1e20)
else:
raise ValueError("Memory limit must be larger than 0, or -1")

return int(memory_limit)


def contract_path(*operands, **kwargs):
"""
Evaluates the lowest cost einsum-like contraction order.
Expand All @@ -25,7 +89,7 @@ def contract_path(*operands, **kwargs):
Specifies the subscripts for summation.
*operands : list of array_like
These are the arrays for the operation.
path : bool or list, optional (default: ``auto``)
optimize : bool or list, optional (default: ``auto``)
Choose the type of path.

- if a list is given uses this as the path.
Expand Down Expand Up @@ -111,12 +175,18 @@ def contract_path(*operands, **kwargs):
"""

# Make sure all keywords are valid
valid_contract_kwargs = ['path', 'memory_limit', 'einsum_call', 'use_blas']
valid_contract_kwargs = ['optimize', 'path', 'memory_limit', 'einsum_call', 'use_blas']
unknown_kwargs = [k for (k, v) in kwargs.items() if k not in valid_contract_kwargs]
if len(unknown_kwargs):
raise TypeError("einsum_path: Did not understand the following kwargs: %s" % unknown_kwargs)
raise TypeError("einsum_path: Did not understand the following kwargs: {}".format(unknown_kwargs))

if 'path' in kwargs:
import warnings
warnings.warn("The 'path' keyword argument is deprecated in favor of 'optimize'.", DeprecationWarning)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, this is a good change overall.

path_type = kwargs.pop('path')
else:
path_type = kwargs.pop('optimize', 'auto')

path_type = kwargs.pop('path', 'auto')
memory_limit = kwargs.pop('memory_limit', None)

# Hidden option, only einsum should call this
Expand All @@ -139,8 +209,8 @@ def contract_path(*operands, **kwargs):
sh = input_shps[tnum]

if len(sh) != len(term):
raise ValueError("Einstein sum subscript %s does not contain the "
"correct number of indices for operand %d." % (input_subscripts[tnum], tnum))
raise ValueError("Einstein sum subscript '{}' does not contain the "
"correct number of indices for operand {}.".format(input_subscripts[tnum], tnum))
for cnum, char in enumerate(term):
dim = sh[cnum]

Expand All @@ -149,25 +219,14 @@ def contract_path(*operands, **kwargs):
if dimension_dict[char] == 1:
dimension_dict[char] = dim
elif dim not in (1, dimension_dict[char]):
raise ValueError("Size of label '%s' for operand %d (%d) "
"does not match previous terms (%d)." % (char, tnum, dimension_dict[char], dim))
raise ValueError("Size of label '{}' for operand {} ({}) does not match previous "
"terms ({}).".format(char, tnum, dimension_dict[char], dim))
else:
dimension_dict[char] = dim

# Compute size of each input array plus the output array
size_list = [helpers.compute_size_by_dict(term, dimension_dict) for term in input_list + [output_subscript]]
out_size = max(size_list)

if memory_limit is None:
memory_arg = out_size
else:
if memory_limit < 1:
if memory_limit == -1:
memory_arg = int(1e20)
else:
raise ValueError("Memory limit must be larger than 0, or -1")
else:
memory_arg = int(memory_limit)
memory_arg = _choose_memory_arg(memory_limit, size_list)

num_ops = len(input_list)

Expand Down Expand Up @@ -195,7 +254,7 @@ def contract_path(*operands, **kwargs):
elif path_type in ("greedy", "opportunistic", "auto"):
path = paths.greedy(input_sets, output_set, dimension_dict, memory_arg)
else:
raise KeyError("Path name %s not found" % path_type)
raise KeyError("Path name '{}' not found".format(path_type))

cost_list = []
scale_list = []
Expand Down Expand Up @@ -247,26 +306,8 @@ def contract_path(*operands, **kwargs):
if einsum_call_arg:
return operands, contraction_list

# Return the path along with a nice string representation
overall_contraction = input_subscripts + "->" + output_subscript
header = ("scaling", "BLAS", "current", "remaining")

path_print = " Complete contraction: %s\n" % overall_contraction
path_print += " Naive scaling: %d\n" % len(indices)
path_print += " Optimized scaling: %d\n" % max(scale_list)
path_print += " Naive FLOP count: %.3e\n" % naive_cost
path_print += " Optimized FLOP count: %.3e\n" % opt_cost
path_print += " Theoretical speedup: %3.3f\n" % (naive_cost / float(opt_cost))
path_print += " Largest intermediate: %.3e elements\n" % max(size_list)
path_print += "-" * 80 + "\n"
path_print += "%6s %11s %22s %37s\n" % header
path_print += "-" * 80

for n, contraction in enumerate(contraction_list):
inds, idx_rm, einsum_str, remaining, do_blas = contraction
remaining_str = ",".join(remaining) + "->" + output_subscript
path_run = (scale_list[n], do_blas, einsum_str, remaining_str)
path_print += "\n%4d %14s %22s %37s" % path_run
path_print = PathInfo(contraction_list, input_subscripts, output_subscript,
indices, scale_list, naive_cost, opt_cost, size_list)

return path, path_print

Expand Down Expand Up @@ -350,11 +391,15 @@ def contract(*operands, **kwargs):
contracting the listed tensors. Scales exponentially with
the number of terms in the contraction.

memory_limit : int or None (default : None)
The upper limit of the size of tensor created, by default, this will be
memory_limit : {None, int, 'max_input'} (default: None)
Give the upper bound of the largest intermediate tensor contract will build.
By default (None) will size the ``memory_limit`` as the largest input tensor.
Users can also specify ``-1`` to allow arbitrarily large tensors to be built.

- None or -1 means there is no limit
- 'max_input' means the limit is set as largest input tensor
- a positive integer is taken as an explicit limit on the number of elements

The default is None. Note that imposing a limit can make contractions
exponentially slower to perform.
backend : str, optional (default: ``numpy``)
Which library to use to perform the required ``tensordot``, ``transpose``
and ``einsum`` calls. Should match the types of arrays supplied, See
Expand Down Expand Up @@ -407,14 +452,14 @@ def contract(*operands, **kwargs):
# Make sure remaining keywords are valid for einsum
unknown_kwargs = [k for (k, v) in kwargs.items() if k not in valid_einsum_kwargs]
if len(unknown_kwargs):
raise TypeError("Did not understand the following kwargs: %s" % unknown_kwargs)
raise TypeError("Did not understand the following kwargs: {}".format(unknown_kwargs))

if gen_expression:
full_str = operands[0]

# Build the contraction list and operand
operands, contraction_list = contract_path(
*operands, path=optimize_arg, memory_limit=memory_limit, einsum_call=True, use_blas=use_blas)
*operands, optimize=optimize_arg, memory_limit=memory_limit, einsum_call=True, use_blas=use_blas)

# check if performing contraction or just building expression
if gen_expression:
Expand Down Expand Up @@ -630,13 +675,13 @@ def __call__(self, *arrays, **kwargs):

if kwargs:
raise ValueError("The only valid keyword arguments to a `ContractExpression` "
"call are `out=` or `backend=`. Got: %s." % kwargs)
"call are `out=` or `backend=`. Got: {}.".format(kwargs))

correct_num_args = self._full_num_args if evaluate_constants else self.num_args

if len(arrays) != correct_num_args:
raise ValueError("This `ContractExpression` takes exactly %s array arguments "
"but received %s." % (self.num_args, len(arrays)))
raise ValueError("This `ContractExpression` takes exactly {} array arguments "
"but received {}.".format(self.num_args, len(arrays)))

if self._constants_dict and not evaluate_constants:
# fill in the missing non-constant terms with newly supplied arrays
Expand All @@ -657,7 +702,7 @@ def __call__(self, *arrays, **kwargs):
original_msg = str(err.args) if err.args else ""
msg = ("Internal error while evaluating `ContractExpression`. Note that few checks are performed"
" - the number and rank of the array arguments must match the original expression. "
"The internal error was: '%s'" % original_msg, )
"The internal error was: '{}'".format(original_msg), )
err.args = msg
raise

Expand All @@ -669,13 +714,13 @@ def __repr__(self):
return "<ContractExpression('{}'{})>".format(self.contraction, constants_repr)

def __str__(self):
s = self.__repr__()
s = [self.__repr__()]
for i, c in enumerate(self.contraction_list):
s += "\n %i. " % (i + 1)
s += "'%s'" % c[2] + (" [%s]" % c[-1] if c[-1] else "")
s.append("\n {}. ".format(i + 1))
s.append("'{}'".format(c[2]) + (" [{}]".format(c[-1]) if c[-1] else ""))
if self.einsum_kwargs:
s += "\neinsum_kwargs=%s" % self.einsum_kwargs
return s
s.append("\neinsum_kwargs={}".format(self.einsum_kwargs))
return "".join(s)


def shape_only(shape):
Expand Down Expand Up @@ -751,8 +796,8 @@ def contract_expression(subscripts, *shapes, **kwargs):

for arg in ('out', 'backend'):
if kwargs.get(arg, None) is not None:
raise ValueError("'%s' should only be specified when calling a "
"`ContractExpression`, not when building it." % arg)
raise ValueError("'{}' should only be specified when calling a "
"`ContractExpression`, not when building it.".format(arg))

kwargs['_gen_expression'] = True

Expand Down
Loading