Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make get_symbol(large_number) work in Python2 #48

Merged
merged 5 commits into from
Aug 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions opt_einsum/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Python 2/3 compatability shim

try:
# Python 2
get_chr = unichr
strings = (str, type(get_chr(300)))
except NameError:
# Python 3
get_chr = chr
strings = str
5 changes: 3 additions & 2 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from . import backends
from . import blas
from . import compat
from . import helpers
from . import parser
from . import paths
Expand Down Expand Up @@ -176,7 +177,7 @@ def contract_path(*operands, **kwargs):
naive_cost = helpers.flop_count(indices, inner_product, num_ops, dimension_dict)

# Compute the path
if not isinstance(path_type, str):
if not isinstance(path_type, compat.strings):
path = path_type
elif num_ops == 1:
# Nothing to be optimized
Expand Down Expand Up @@ -274,7 +275,7 @@ def _einsum(*operands, **kwargs):
"""
fn = backends.get_func('einsum', kwargs.pop('backend', 'numpy'))

if not isinstance(operands[0], str):
if not isinstance(operands[0], compat.strings):
return fn(*operands, **kwargs)

einsum_str, operands = operands[0], operands[1:]
Expand Down
6 changes: 4 additions & 2 deletions opt_einsum/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import numpy as np

from . import compat

einsum_symbols_base = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'


Expand Down Expand Up @@ -40,7 +42,7 @@ def get_symbol(i):
"""
if i < 52:
return einsum_symbols_base[i]
return chr(i + 140)
return compat.get_chr(i + 140)


def gen_unused_symbols(used, n):
Expand Down Expand Up @@ -133,7 +135,7 @@ def parse_einsum_input(operands):
if len(operands) == 0:
raise ValueError("No input operands")

if isinstance(operands[0], str):
if isinstance(operands[0], compat.strings):
subscripts = operands[0].replace(" ", "")
operands = [possibly_convert_to_numpy(x) for x in operands[1:]]

Expand Down
2 changes: 1 addition & 1 deletion opt_einsum/sharing.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def cached_einsum(*args, **kwargs):
canonical = sorted(zip(inputs, map(id, operands)), key=lambda x: x[1])
canonical_ids = tuple(id_ for _, id_ in canonical)
canonical_inputs = ','.join(input_ for input_, _ in canonical)
canonical_equation = alpha_canonicalize('{}->{}'.format(canonical_inputs, output))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason to avoid format here?

Copy link
Contributor Author

@fritzo fritzo Aug 23, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, 'blah {}'.format(u'\x80') fails in Python 2. I couldn't figure out how to write a test for this, but it fixed my error in an application.

canonical_equation = alpha_canonicalize(canonical_inputs + "->" + output)
key = 'einsum', backend, canonical_equation, canonical_ids
return _memoize(key, einsum, equation, *operands, backend=backend)

Expand Down
11 changes: 4 additions & 7 deletions opt_einsum/tests/test_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import pytest

from opt_einsum import contract, contract_path, helpers, contract_expression
from opt_einsum import compat, contract, contract_path, helpers, contract_expression

tests = [
# Test hadamard-like products
Expand Down Expand Up @@ -117,15 +117,14 @@ def test_drop_in_replacement(string):
assert np.allclose(opt, np.einsum(string, *views))


@pytest.mark.skipif(sys.version_info[0] < 3, reason='requires python3')
@pytest.mark.parametrize("string", tests)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I think there are two more in this file. I think this is good to go after those are patched?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, I believe this is good to go now. Looking forward to using this in Pyro.

def test_compare_greek(string):
views = helpers.build_views(string)

ein = contract(string, *views, optimize=False, use_blas=False)

# convert to greek
string = ''.join(chr(ord(c) + 848) if c not in ',->.' else c for c in string)
string = ''.join(compat.get_chr(ord(c) + 848) if c not in ',->.' else c for c in string)

opt = contract(string, *views, optimize='greedy', use_blas=False)
assert np.allclose(ein, opt)
Expand All @@ -146,15 +145,14 @@ def test_compare_blas(string):
assert np.allclose(ein, opt)


@pytest.mark.skipif(sys.version_info[0] < 3, reason='requires python3')
@pytest.mark.parametrize("string", tests)
def test_compare_blas_greek(string):
views = helpers.build_views(string)

ein = contract(string, *views, optimize=False)

# convert to greek
string = ''.join(chr(ord(c) + 848) if c not in ',->.' else c for c in string)
string = ''.join(compat.get_chr(ord(c) + 848) if c not in ',->.' else c for c in string)

opt = contract(string, *views, optimize='greedy')
assert np.allclose(ein, opt)
Expand All @@ -163,10 +161,9 @@ def test_compare_blas_greek(string):
assert np.allclose(ein, opt)


@pytest.mark.skipif(sys.version_info[0] < 3, reason='requires python3')
def test_some_non_alphabet_maintains_order():
# 'c beta a' should automatically go to -> 'a c beta'
string = 'c' + chr(ord('b') + 848) + 'a'
string = 'c' + compat.get_chr(ord('b') + 848) + 'a'
# but beta will be temporarily replaced with 'b' for which 'cba->abc'
# so check manual output kicks in:
x = np.random.rand(2, 3, 4)
Expand Down
1 change: 0 additions & 1 deletion opt_einsum/tests/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ def test_singleton_dimension_broadcast():
assert np.allclose(res2, np.full((1, 5), 5))


@pytest.mark.skipif(sys.version_info.major == 2, reason="Requires python 3.")
def test_large_int_input_format():
string = 'ab,bc,cd'
x, y, z = build_views(string)
Expand Down
13 changes: 13 additions & 0 deletions opt_einsum/tests/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
the various path helper functions.
"""

import itertools

import numpy as np
import pytest

Expand Down Expand Up @@ -174,3 +176,14 @@ def test_can_optimize_outer_products():
a, b, c = [np.random.randn(10, 10) for _ in range(3)]
d = np.random.randn(10, 2)
assert oe.contract_path("ab,cd,ef,fg", a, b, c, d, path='greedy')[0] == [(2, 3), (0, 2), (0, 1)]


@pytest.mark.parametrize('num_symbols', [2, 3, 26, 26 + 26, 256 - 140, 300])
def test_large_path(num_symbols):
symbols = ''.join(oe.get_symbol(i) for i in range(num_symbols))
dimension_dict = dict(zip(symbols, itertools.cycle([2, 3, 4])))
expression = ','.join(symbols[t:t+2] for t in range(num_symbols - 1))
tensors = oe.helpers.build_views(expression, dimension_dict=dimension_dict)

# Check that path construction does not crash
oe.contract_path(expression, *tensors, path='greedy')