-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathtest_interface.py
106 lines (95 loc) · 2.99 KB
/
test_interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import pytest
import cotengra as ctg
import numpy as np
@pytest.mark.parametrize("optimize_type", ["preset", "list", "tuple"])
def test_array_contract_path_cache(optimize_type):
    """Exercise the ``cache`` flag of ``ctg.array_contract_path``.

    The same random contraction is requested twice with ``cache=True`` and
    once with ``cache=False``. The two cached results must be the very same
    object, the uncached result must be a different object (the ``"tuple"``
    parametrization is exempt from this second check), and contracting with
    the returned path must agree with ``np.einsum``.
    """
    # One ``optimize`` value per parametrized flavour: a named preset, an
    # explicit path given as a list, and the same path given as a tuple.
    explicit_path = [(0, 1)] * 9
    optimize = {
        "preset": "auto",
        "list": explicit_path,
        "tuple": tuple(explicit_path),
    }[optimize_type]

    inputs, output, shapes, size_dict = ctg.utils.rand_equation(10, 3)
    arrays = ctg.utils.make_arrays_from_inputs(inputs, size_dict)

    def find_path(cache):
        # Identical request every time; only the ``cache`` flag varies.
        return ctg.array_contract_path(
            inputs, output, shapes=shapes, cache=cache, optimize=optimize
        )

    pa = find_path(True)
    pb = find_path(True)
    pc = find_path(False)

    # Repeated cached lookups hand back the same object.
    assert pa is pb
    # NOTE(review): presumably an explicit tuple path can be handed back
    # unchanged, so identity may coincide even without caching - confirm.
    if optimize_type != "tuple":
        assert pb is not pc

    # The path must also be usable: compare against numpy's result.
    eq = ctg.utils.inputs_output_to_eq(inputs, output)
    expected = np.einsum(eq, *arrays)
    actual = ctg.einsum(eq, *arrays, optimize=pa)
    assert np.allclose(expected, actual)
@pytest.mark.parametrize("optimize_type", ["preset", "list", "tuple"])
def test_array_contract_expression_cache(optimize_type):
    """Exercise the ``cache`` flag of ``ctg.array_contract_expression``.

    The same random contraction is requested twice with ``cache=True`` and
    once with ``cache=False``. The two cached expressions must be the very
    same object and the uncached one must be a distinct object. Both the
    cached and the uncached expression are then evaluated and compared
    against ``np.einsum``.
    """
    if optimize_type == "preset":
        optimize = "auto"
    elif optimize_type == "list":
        optimize = [(0, 1)] * 9
    elif optimize_type == "tuple":
        optimize = tuple([(0, 1)] * 9)

    inputs, output, shapes, size_dict = ctg.utils.rand_equation(10, 3)
    arrays = ctg.utils.make_arrays_from_inputs(inputs, size_dict)

    expra = ctg.array_contract_expression(
        inputs,
        output,
        shapes=shapes,
        cache=True,
        optimize=optimize,
    )
    exprb = ctg.array_contract_expression(
        inputs,
        output,
        shapes=shapes,
        cache=True,
        optimize=optimize,
    )
    exprc = ctg.array_contract_expression(
        inputs,
        output,
        shapes=shapes,
        cache=False,
        optimize=optimize,
    )

    # Cached requests return the same object; the uncached one does not.
    assert expra is exprb
    assert exprb is not exprc

    # Reference result computed by numpy.
    eq = ctg.utils.inputs_output_to_eq(inputs, output)
    xa = np.einsum(eq, *arrays)

    xb = expra(*arrays)
    assert np.allclose(xa, xb)

    # Calling the cached expression a second time must still be correct.
    xc = expra(*arrays)
    assert np.allclose(xa, xc)

    # The uncached expression was previously only identity-checked; evaluate
    # it as well so that a broken ``cache=False`` expression is detected.
    xd = exprc(*arrays)
    assert np.allclose(xa, xd)
def test_einsum_formats_interleaved():
    """Check that ``ctg.einsum`` accepts the interleaved calling format.

    The arguments alternate between an operand and its list of integer
    index labels, followed by a final list of output labels. The result is
    compared against ``np.einsum`` called with the same arguments.
    """
    # In this test every operand's shape coincides with its index labels,
    # so each label list doubles as the shape passed to ``np.random.rand``.
    operand_labels = ([2, 3, 4], [4, 5, 6], [2, 7])

    args = []
    for labels in operand_labels:
        args.append(np.random.rand(*labels))
        args.append(labels)
    # Trailing label list with no operand after it: the output indices.
    args.append([7, 3])

    expected = np.einsum(*args)
    actual = ctg.einsum(*args)
    assert np.allclose(expected, actual)
@pytest.mark.parametrize(
    "eq,shapes",
    [
        ("c...a,b...c->b...a", [(2, 5, 6, 3), (4, 6, 2)]),
        ("a...a->...", [(3, 3)]),
        ("a...a->...a", [(3, 4, 5, 3)]),
        ("...,...ab->ba...", [(), (2, 3, 4, 5)]),
        ("a,b,ab...c->b...a", [(2,), (3,), (2, 3, 4, 5, 6)]),
    ],
)
def test_einsum_ellipses(eq, shapes):
    """Check ``ctg.einsum`` against ``np.einsum`` for equations using ``...``.

    Parameters
    ----------
    eq : str
        Einsum equation containing ellipses.
    shapes : list of tuple of int
        Shape of the random operand generated for each term of ``eq``.
    """
    # One random operand per requested shape (an empty shape is allowed).
    operands = tuple(np.random.rand(*dims) for dims in shapes)

    expected = np.einsum(eq, *operands)
    actual = ctg.einsum(eq, *operands)
    assert np.allclose(expected, actual)