-
Notifications
You must be signed in to change notification settings - Fork 68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
various changes to path machinery #61
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
8fec17a
path finding changes
jcmgray f8bdece
PathInfo: inherit from object and use join for repr
jcmgray 9d540b4
update docstrings and var names
jcmgray ddd47b7
update string formatting
jcmgray 809a3b1
test rand_equation
jcmgray f297d5f
set seed for rand_equation test
jcmgray File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,70 @@ | |
__all__ = ["contract_path", "contract", "format_const_einsum_str", "ContractExpression", "shape_only", "shape_only"] | ||
|
||
|
||
class PathInfo(object):
    """Summary of a contraction path, printable as a human-readable table.

    Parameters
    ----------
    contraction_list : sequence of tuple
        One ``(inds, idx_rm, einsum_str, remaining, do_blas)`` entry per
        contraction step. Only ``einsum_str``, ``remaining`` and ``do_blas``
        are shown in the report; ``remaining`` must be an iterable of strings.
    input_subscripts : str
        Input side of the overall contraction expression.
    output_subscript : str
        Output side of the overall contraction expression.
    indices : sized collection
        Its length is reported as the naive scaling.
    scale_list : sequence
        Scaling of each step, indexed in parallel with ``contraction_list``.
    naive_cost : number
        FLOP count reported for the naive contraction.
    opt_cost : number
        FLOP count reported for the optimized contraction.
    size_list : iterable of number
        Tensor sizes; only the maximum is kept, as ``largest_intermediate``.
    """

    def __init__(self, contraction_list, input_subscripts, output_subscript,
                 indices, scale_list, naive_cost, opt_cost, size_list):
        self.contraction_list = contraction_list
        self.input_subscripts = input_subscripts
        self.output_subscript = output_subscript
        self.indices = indices
        self.scale_list = scale_list
        self.naive_cost = naive_cost
        self.opt_cost = opt_cost
        self.largest_intermediate = max(size_list)

    def __repr__(self):
        """Return the summary header followed by one table row per contraction."""
        from decimal import Decimal

        # naive costs / speedups can easily reach >~ 1e308, need exact arithmetic
        naive_cost = Decimal(self.naive_cost)
        opt_cost = Decimal(self.opt_cost)
        speedup = naive_cost / opt_cost
        largest_intermediate = Decimal(self.largest_intermediate)

        overall_contraction = self.input_subscripts + "->" + self.output_subscript
        header = ("scaling", "BLAS", "current", "remaining")

        path_print = [
            " Complete contraction: {}\n".format(overall_contraction),
            " Naive scaling: {}\n".format(len(self.indices)),
            " Optimized scaling: {}\n".format(max(self.scale_list)),
            " Naive FLOP count: {:.3e}\n".format(naive_cost),
            " Optimized FLOP count: {:.3e}\n".format(opt_cost),
            " Theoretical speedup: {:3.3f}\n".format(speedup),
            " Largest intermediate: {:.3e} elements\n".format(largest_intermediate),
            "-" * 80 + "\n",
            "{:>6} {:>11} {:>22} {:>37}\n".format(*header),
            "-" * 80
        ]

        for n, contraction in enumerate(self.contraction_list):
            _, _, einsum_str, remaining, do_blas = contraction
            remaining_str = ",".join(remaining) + "->" + self.output_subscript
            # A non-empty format spec applied to a bool uses integer formatting,
            # so ``"{:>14}".format(False)`` would print ``0``. Convert to str
            # first so the BLAS column reads ``False``/``True`` like ``%14s`` did.
            path_run = (self.scale_list[n], str(do_blas), einsum_str, remaining_str)
            path_print.append("\n{:>4} {:>14} {:>22} {:>37}".format(*path_run))

        return "".join(path_print)
|
||
|
||
def _choose_memory_arg(memory_limit, size_list): | ||
if memory_limit == 'max_input': | ||
return max(size_list) | ||
|
||
if memory_limit is None: | ||
return int(1e20) | ||
|
||
if memory_limit < 1: | ||
if memory_limit == -1: | ||
return int(1e20) | ||
else: | ||
raise ValueError("Memory limit must be larger than 0, or -1") | ||
|
||
return int(memory_limit) | ||
|
||
|
||
def contract_path(*operands, **kwargs): | ||
""" | ||
Evaluates the lowest cost einsum-like contraction order. | ||
|
@@ -25,7 +89,7 @@ def contract_path(*operands, **kwargs): | |
Specifies the subscripts for summation. | ||
*operands : list of array_like | ||
These are the arrays for the operation. | ||
path : bool or list, optional (default: ``auto``) | ||
optimize : bool or list, optional (default: ``auto``) | ||
Choose the type of path. | ||
|
||
- if a list is given uses this as the path. | ||
|
@@ -111,12 +175,18 @@ def contract_path(*operands, **kwargs): | |
""" | ||
|
||
# Make sure all keywords are valid | ||
valid_contract_kwargs = ['path', 'memory_limit', 'einsum_call', 'use_blas'] | ||
valid_contract_kwargs = ['optimize', 'path', 'memory_limit', 'einsum_call', 'use_blas'] | ||
unknown_kwargs = [k for (k, v) in kwargs.items() if k not in valid_contract_kwargs] | ||
if len(unknown_kwargs): | ||
raise TypeError("einsum_path: Did not understand the following kwargs: %s" % unknown_kwargs) | ||
raise TypeError("einsum_path: Did not understand the following kwargs: {}".format(unknown_kwargs)) | ||
|
||
if 'path' in kwargs: | ||
import warnings | ||
warnings.warn("The 'path' keyword argument is deprecated in favor of 'optimize'.", DeprecationWarning) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree, this is a good change overall. |
||
path_type = kwargs.pop('path') | ||
else: | ||
path_type = kwargs.pop('optimize', 'auto') | ||
|
||
path_type = kwargs.pop('path', 'auto') | ||
memory_limit = kwargs.pop('memory_limit', None) | ||
|
||
# Hidden option, only einsum should call this | ||
|
@@ -139,8 +209,8 @@ def contract_path(*operands, **kwargs): | |
sh = input_shps[tnum] | ||
|
||
if len(sh) != len(term): | ||
raise ValueError("Einstein sum subscript %s does not contain the " | ||
"correct number of indices for operand %d." % (input_subscripts[tnum], tnum)) | ||
raise ValueError("Einstein sum subscript '{}' does not contain the " | ||
"correct number of indices for operand {}.".format(input_subscripts[tnum], tnum)) | ||
for cnum, char in enumerate(term): | ||
dim = sh[cnum] | ||
|
||
|
@@ -149,25 +219,14 @@ def contract_path(*operands, **kwargs): | |
if dimension_dict[char] == 1: | ||
dimension_dict[char] = dim | ||
elif dim not in (1, dimension_dict[char]): | ||
raise ValueError("Size of label '%s' for operand %d (%d) " | ||
"does not match previous terms (%d)." % (char, tnum, dimension_dict[char], dim)) | ||
raise ValueError("Size of label '{}' for operand {} ({}) does not match previous " | ||
"terms ({}).".format(char, tnum, dimension_dict[char], dim)) | ||
else: | ||
dimension_dict[char] = dim | ||
|
||
# Compute size of each input array plus the output array | ||
size_list = [helpers.compute_size_by_dict(term, dimension_dict) for term in input_list + [output_subscript]] | ||
out_size = max(size_list) | ||
|
||
if memory_limit is None: | ||
memory_arg = out_size | ||
else: | ||
if memory_limit < 1: | ||
if memory_limit == -1: | ||
memory_arg = int(1e20) | ||
else: | ||
raise ValueError("Memory limit must be larger than 0, or -1") | ||
else: | ||
memory_arg = int(memory_limit) | ||
memory_arg = _choose_memory_arg(memory_limit, size_list) | ||
|
||
num_ops = len(input_list) | ||
|
||
|
@@ -195,7 +254,7 @@ def contract_path(*operands, **kwargs): | |
elif path_type in ("greedy", "opportunistic", "auto"): | ||
path = paths.greedy(input_sets, output_set, dimension_dict, memory_arg) | ||
else: | ||
raise KeyError("Path name %s not found" % path_type) | ||
raise KeyError("Path name '{}' not found".format(path_type)) | ||
|
||
cost_list = [] | ||
scale_list = [] | ||
|
@@ -247,26 +306,8 @@ def contract_path(*operands, **kwargs): | |
if einsum_call_arg: | ||
return operands, contraction_list | ||
|
||
# Return the path along with a nice string representation | ||
overall_contraction = input_subscripts + "->" + output_subscript | ||
header = ("scaling", "BLAS", "current", "remaining") | ||
|
||
path_print = " Complete contraction: %s\n" % overall_contraction | ||
path_print += " Naive scaling: %d\n" % len(indices) | ||
path_print += " Optimized scaling: %d\n" % max(scale_list) | ||
path_print += " Naive FLOP count: %.3e\n" % naive_cost | ||
path_print += " Optimized FLOP count: %.3e\n" % opt_cost | ||
path_print += " Theoretical speedup: %3.3f\n" % (naive_cost / float(opt_cost)) | ||
path_print += " Largest intermediate: %.3e elements\n" % max(size_list) | ||
path_print += "-" * 80 + "\n" | ||
path_print += "%6s %11s %22s %37s\n" % header | ||
path_print += "-" * 80 | ||
|
||
for n, contraction in enumerate(contraction_list): | ||
inds, idx_rm, einsum_str, remaining, do_blas = contraction | ||
remaining_str = ",".join(remaining) + "->" + output_subscript | ||
path_run = (scale_list[n], do_blas, einsum_str, remaining_str) | ||
path_print += "\n%4d %14s %22s %37s" % path_run | ||
path_print = PathInfo(contraction_list, input_subscripts, output_subscript, | ||
indices, scale_list, naive_cost, opt_cost, size_list) | ||
|
||
return path, path_print | ||
|
||
|
@@ -350,11 +391,15 @@ def contract(*operands, **kwargs): | |
contracting the listed tensors. Scales exponentially with | ||
the number of terms in the contraction. | ||
|
||
memory_limit : int or None (default : None) | ||
The upper limit of the size of tensor created, by default, this will be | ||
memory_limit : {None, int, 'max_input'} (default: None) | ||
Give the upper bound of the largest intermediate tensor contract will build. | ||
By default (None) will size the ``memory_limit`` as the largest input tensor. | ||
Users can also specify ``-1`` to allow arbitrarily large tensors to be built. | ||
|
||
- None or -1 means there is no limit | ||
- 'max_input' means the limit is set as largest input tensor | ||
- a positive integer is taken as an explicit limit on the number of elements | ||
|
||
The default is None. Note that imposing a limit can make contractions | ||
exponentially slower to perform. | ||
backend : str, optional (default: ``numpy``) | ||
Which library to use to perform the required ``tensordot``, ``transpose`` | ||
and ``einsum`` calls. Should match the types of arrays supplied, See | ||
|
@@ -407,14 +452,14 @@ def contract(*operands, **kwargs): | |
# Make sure remaining keywords are valid for einsum | ||
unknown_kwargs = [k for (k, v) in kwargs.items() if k not in valid_einsum_kwargs] | ||
if len(unknown_kwargs): | ||
raise TypeError("Did not understand the following kwargs: %s" % unknown_kwargs) | ||
raise TypeError("Did not understand the following kwargs: {}".format(unknown_kwargs)) | ||
|
||
if gen_expression: | ||
full_str = operands[0] | ||
|
||
# Build the contraction list and operand | ||
operands, contraction_list = contract_path( | ||
*operands, path=optimize_arg, memory_limit=memory_limit, einsum_call=True, use_blas=use_blas) | ||
*operands, optimize=optimize_arg, memory_limit=memory_limit, einsum_call=True, use_blas=use_blas) | ||
|
||
# check if performing contraction or just building expression | ||
if gen_expression: | ||
|
@@ -630,13 +675,13 @@ def __call__(self, *arrays, **kwargs): | |
|
||
if kwargs: | ||
raise ValueError("The only valid keyword arguments to a `ContractExpression` " | ||
"call are `out=` or `backend=`. Got: %s." % kwargs) | ||
"call are `out=` or `backend=`. Got: {}.".format(kwargs)) | ||
|
||
correct_num_args = self._full_num_args if evaluate_constants else self.num_args | ||
|
||
if len(arrays) != correct_num_args: | ||
raise ValueError("This `ContractExpression` takes exactly %s array arguments " | ||
"but received %s." % (self.num_args, len(arrays))) | ||
raise ValueError("This `ContractExpression` takes exactly {} array arguments " | ||
"but received {}.".format(self.num_args, len(arrays))) | ||
|
||
if self._constants_dict and not evaluate_constants: | ||
# fill in the missing non-constant terms with newly supplied arrays | ||
|
@@ -657,7 +702,7 @@ def __call__(self, *arrays, **kwargs): | |
original_msg = str(err.args) if err.args else "" | ||
msg = ("Internal error while evaluating `ContractExpression`. Note that few checks are performed" | ||
" - the number and rank of the array arguments must match the original expression. " | ||
"The internal error was: '%s'" % original_msg, ) | ||
"The internal error was: '{}'".format(original_msg), ) | ||
err.args = msg | ||
raise | ||
|
||
|
@@ -669,13 +714,13 @@ def __repr__(self): | |
return "<ContractExpression('{}'{})>".format(self.contraction, constants_repr) | ||
|
||
def __str__(self): | ||
s = self.__repr__() | ||
s = [self.__repr__()] | ||
for i, c in enumerate(self.contraction_list): | ||
s += "\n %i. " % (i + 1) | ||
s += "'%s'" % c[2] + (" [%s]" % c[-1] if c[-1] else "") | ||
s.append("\n {}. ".format(i + 1)) | ||
s.append("'{}'".format(c[2]) + (" [{}]".format(c[-1]) if c[-1] else "")) | ||
if self.einsum_kwargs: | ||
s += "\neinsum_kwargs=%s" % self.einsum_kwargs | ||
return s | ||
s.append("\neinsum_kwargs={}".format(self.einsum_kwargs)) | ||
return "".join(s) | ||
|
||
|
||
def shape_only(shape): | ||
|
@@ -751,8 +796,8 @@ def contract_expression(subscripts, *shapes, **kwargs): | |
|
||
for arg in ('out', 'backend'): | ||
if kwargs.get(arg, None) is not None: | ||
raise ValueError("'%s' should only be specified when calling a " | ||
"`ContractExpression`, not when building it." % arg) | ||
raise ValueError("'{}' should only be specified when calling a " | ||
"`ContractExpression`, not when building it.".format(arg)) | ||
|
||
kwargs['_gen_expression'] = True | ||
|
||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 for `format`. If you are thinking about it, can we replace the old `%` syntax elsewhere throughout the code so that we are consistent? This is an ancillary point; we can spin it off into its own issue.