-
Notifications
You must be signed in to change notification settings - Fork 68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make get_symbol(large_number) work in Python2 #48
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Python 2/3 compatability shim | ||
|
||
try: | ||
# Python 2 | ||
get_chr = unichr | ||
strings = (str, type(get_chr(300))) | ||
except NameError: | ||
# Python 3 | ||
get_chr = chr | ||
strings = str |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
import numpy as np | ||
import pytest | ||
|
||
from opt_einsum import contract, contract_path, helpers, contract_expression | ||
from opt_einsum import compat, contract, contract_path, helpers, contract_expression | ||
|
||
tests = [ | ||
# Test hadamard-like products | ||
|
@@ -117,15 +117,14 @@ def test_drop_in_replacement(string): | |
assert np.allclose(opt, np.einsum(string, *views)) | ||
|
||
|
||
@pytest.mark.skipif(sys.version_info[0] < 3, reason='requires python3') | ||
@pytest.mark.parametrize("string", tests) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, I think there are two more in this file. I think this is good to go after those are patched? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great, I believe this is good to go now. Looking forward to using this in Pyro. |
||
def test_compare_greek(string): | ||
views = helpers.build_views(string) | ||
|
||
ein = contract(string, *views, optimize=False, use_blas=False) | ||
|
||
# convert to greek | ||
string = ''.join(chr(ord(c) + 848) if c not in ',->.' else c for c in string) | ||
string = ''.join(compat.get_chr(ord(c) + 848) if c not in ',->.' else c for c in string) | ||
|
||
opt = contract(string, *views, optimize='greedy', use_blas=False) | ||
assert np.allclose(ein, opt) | ||
|
@@ -146,15 +145,14 @@ def test_compare_blas(string): | |
assert np.allclose(ein, opt) | ||
|
||
|
||
@pytest.mark.skipif(sys.version_info[0] < 3, reason='requires python3') | ||
@pytest.mark.parametrize("string", tests) | ||
def test_compare_blas_greek(string): | ||
views = helpers.build_views(string) | ||
|
||
ein = contract(string, *views, optimize=False) | ||
|
||
# convert to greek | ||
string = ''.join(chr(ord(c) + 848) if c not in ',->.' else c for c in string) | ||
string = ''.join(compat.get_chr(ord(c) + 848) if c not in ',->.' else c for c in string) | ||
|
||
opt = contract(string, *views, optimize='greedy') | ||
assert np.allclose(ein, opt) | ||
|
@@ -163,10 +161,9 @@ def test_compare_blas_greek(string): | |
assert np.allclose(ein, opt) | ||
|
||
|
||
@pytest.mark.skipif(sys.version_info[0] < 3, reason='requires python3') | ||
def test_some_non_alphabet_maintains_order(): | ||
# 'c beta a' should automatically go to -> 'a c beta' | ||
string = 'c' + chr(ord('b') + 848) + 'a' | ||
string = 'c' + compat.get_chr(ord('b') + 848) + 'a' | ||
# but beta will be temporarily replaced with 'b' for which 'cba->abc' | ||
# so check manual output kicks in: | ||
x = np.random.rand(2, 3, 4) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any reason to avoid
format
here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes,
'blah {}'.format(u'\x80')
fails in Python 2. I couldn't figure out how to write a test for this, but it fixed my error in an application.