Skip to content

Commit 5704223

Browse files
committed
Fix ch.tensordot for numpy 1.14+ (#2)
1 parent 58a76a8 commit 5704223

File tree

4 files changed

+239
-8
lines changed

4 files changed

+239
-8
lines changed

.travis.yml

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,6 @@ before_install:
1010
install:
1111
- pip install --upgrade pip
1212
- travis_wait pip install -r requirements.txt
13-
script: make test
13+
script:
14+
- python -c 'import numpy; print numpy.version.version'
15+
- make test

chumpy/ch_ops.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -793,13 +793,13 @@ def nonzero(a):
793793
a = a.r
794794
return np.nonzero(a)
795795

796-
# Try to pull the code for tensordot in from numpy and reinterpret it using chumpy ops
797-
try:
798-
import inspect
799-
exec(''.join(inspect.getsourcelines(np.tensordot)[0]))
800-
__all__ += ['tensordot']
801-
except:
802-
pass
796+
# Pull the code for tensordot in from numpy and reinterpret it using chumpy ops
797+
import os
798+
source_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'np_tensordot.py')
799+
with open(source_path, 'r') as f:
800+
source_lines = f.readlines()
801+
exec(''.join(source_lines))
802+
__all__ += ['tensordot']
803803

804804

805805

chumpy/np_tensordot.py

+228
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
# Up to numpy 1.13, the numpy implementation of tensordot could be
2+
# reinterpreted using chumpy. With numpy 1.14 the implementation started using
3+
# ufunc.multiply.reduce which can't be understood by chumpy. This is the
4+
# chumpy-compatible implementation of tensodrot from numpy 1.13.3.
5+
#
6+
# i.e.
7+
#
8+
# import inspect
9+
# with open('np_tensordot.py', 'w') as f:
10+
# f.write(''.join(inspect.getsourcelines(np.tensordot)[0]))
11+
12+
"""
13+
Copyright (c) 2005-2017, NumPy Developers.
14+
All rights reserved.
15+
16+
Redistribution and use in source and binary forms, with or without
17+
modification, are permitted provided that the following conditions are
18+
met:
19+
20+
* Redistributions of source code must retain the above copyright
21+
notice, this list of conditions and the following disclaimer.
22+
23+
* Redistributions in binary form must reproduce the above
24+
copyright notice, this list of conditions and the following
25+
disclaimer in the documentation and/or other materials provided
26+
with the distribution.
27+
28+
* Neither the name of the NumPy Developers nor the names of any
29+
contributors may be used to endorse or promote products derived
30+
from this software without specific prior written permission.
31+
32+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
33+
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
34+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
35+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
36+
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
37+
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
38+
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
39+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
40+
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
41+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
42+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
43+
"""
44+
45+
def tensordot(a, b, axes=2):
46+
"""
47+
Compute tensor dot product along specified axes for arrays >= 1-D.
48+
49+
Given two tensors (arrays of dimension greater than or equal to one),
50+
`a` and `b`, and an array_like object containing two array_like
51+
objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
52+
elements (components) over the axes specified by ``a_axes`` and
53+
``b_axes``. The third argument can be a single non-negative
54+
integer_like scalar, ``N``; if it is such, then the last ``N``
55+
dimensions of `a` and the first ``N`` dimensions of `b` are summed
56+
over.
57+
58+
Parameters
59+
----------
60+
a, b : array_like, len(shape) >= 1
61+
Tensors to "dot".
62+
63+
axes : int or (2,) array_like
64+
* integer_like
65+
If an int N, sum over the last N axes of `a` and the first N axes
66+
of `b` in order. The sizes of the corresponding axes must match.
67+
* (2,) array_like
68+
Or, a list of axes to be summed over, first sequence applying to `a`,
69+
second to `b`. Both elements array_like must be of the same length.
70+
71+
See Also
72+
--------
73+
dot, einsum
74+
75+
Notes
76+
-----
77+
Three common use cases are:
78+
* ``axes = 0`` : tensor product :math:`a\\otimes b`
79+
* ``axes = 1`` : tensor dot product :math:`a\\cdot b`
80+
* ``axes = 2`` : (default) tensor double contraction :math:`a:b`
81+
82+
When `axes` is integer_like, the sequence for evaluation will be: first
83+
the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
84+
Nth axis in `b` last.
85+
86+
When there is more than one axis to sum over - and they are not the last
87+
(first) axes of `a` (`b`) - the argument `axes` should consist of
88+
two sequences of the same length, with the first axis to sum over given
89+
first in both sequences, the second axis second, and so forth.
90+
91+
Examples
92+
--------
93+
A "traditional" example:
94+
95+
>>> a = np.arange(60.).reshape(3,4,5)
96+
>>> b = np.arange(24.).reshape(4,3,2)
97+
>>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
98+
>>> c.shape
99+
(5, 2)
100+
>>> c
101+
array([[ 4400., 4730.],
102+
[ 4532., 4874.],
103+
[ 4664., 5018.],
104+
[ 4796., 5162.],
105+
[ 4928., 5306.]])
106+
>>> # A slower but equivalent way of computing the same...
107+
>>> d = np.zeros((5,2))
108+
>>> for i in range(5):
109+
... for j in range(2):
110+
... for k in range(3):
111+
... for n in range(4):
112+
... d[i,j] += a[k,n,i] * b[n,k,j]
113+
>>> c == d
114+
array([[ True, True],
115+
[ True, True],
116+
[ True, True],
117+
[ True, True],
118+
[ True, True]], dtype=bool)
119+
120+
An extended example taking advantage of the overloading of + and \\*:
121+
122+
>>> a = np.array(range(1, 9))
123+
>>> a.shape = (2, 2, 2)
124+
>>> A = np.array(('a', 'b', 'c', 'd'), dtype=object)
125+
>>> A.shape = (2, 2)
126+
>>> a; A
127+
array([[[1, 2],
128+
[3, 4]],
129+
[[5, 6],
130+
[7, 8]]])
131+
array([[a, b],
132+
[c, d]], dtype=object)
133+
134+
>>> np.tensordot(a, A) # third argument default is 2 for double-contraction
135+
array([abbcccdddd, aaaaabbbbbbcccccccdddddddd], dtype=object)
136+
137+
>>> np.tensordot(a, A, 1)
138+
array([[[acc, bdd],
139+
[aaacccc, bbbdddd]],
140+
[[aaaaacccccc, bbbbbdddddd],
141+
[aaaaaaacccccccc, bbbbbbbdddddddd]]], dtype=object)
142+
143+
>>> np.tensordot(a, A, 0) # tensor product (result too long to incl.)
144+
array([[[[[a, b],
145+
[c, d]],
146+
...
147+
148+
>>> np.tensordot(a, A, (0, 1))
149+
array([[[abbbbb, cddddd],
150+
[aabbbbbb, ccdddddd]],
151+
[[aaabbbbbbb, cccddddddd],
152+
[aaaabbbbbbbb, ccccdddddddd]]], dtype=object)
153+
154+
>>> np.tensordot(a, A, (2, 1))
155+
array([[[abb, cdd],
156+
[aaabbbb, cccdddd]],
157+
[[aaaaabbbbbb, cccccdddddd],
158+
[aaaaaaabbbbbbbb, cccccccdddddddd]]], dtype=object)
159+
160+
>>> np.tensordot(a, A, ((0, 1), (0, 1)))
161+
array([abbbcccccddddddd, aabbbbccccccdddddddd], dtype=object)
162+
163+
>>> np.tensordot(a, A, ((2, 1), (1, 0)))
164+
array([acccbbdddd, aaaaacccccccbbbbbbdddddddd], dtype=object)
165+
166+
"""
167+
try:
168+
iter(axes)
169+
except:
170+
axes_a = list(range(-axes, 0))
171+
axes_b = list(range(0, axes))
172+
else:
173+
axes_a, axes_b = axes
174+
try:
175+
na = len(axes_a)
176+
axes_a = list(axes_a)
177+
except TypeError:
178+
axes_a = [axes_a]
179+
na = 1
180+
try:
181+
nb = len(axes_b)
182+
axes_b = list(axes_b)
183+
except TypeError:
184+
axes_b = [axes_b]
185+
nb = 1
186+
187+
a, b = asarray(a), asarray(b)
188+
as_ = a.shape
189+
nda = a.ndim
190+
bs = b.shape
191+
ndb = b.ndim
192+
equal = True
193+
if na != nb:
194+
equal = False
195+
else:
196+
for k in range(na):
197+
if as_[axes_a[k]] != bs[axes_b[k]]:
198+
equal = False
199+
break
200+
if axes_a[k] < 0:
201+
axes_a[k] += nda
202+
if axes_b[k] < 0:
203+
axes_b[k] += ndb
204+
if not equal:
205+
raise ValueError("shape-mismatch for sum")
206+
207+
# Move the axes to sum over to the end of "a"
208+
# and to the front of "b"
209+
notin = [k for k in range(nda) if k not in axes_a]
210+
newaxes_a = notin + axes_a
211+
N2 = 1
212+
for axis in axes_a:
213+
N2 *= as_[axis]
214+
newshape_a = (-1, N2)
215+
olda = [as_[axis] for axis in notin]
216+
217+
notin = [k for k in range(ndb) if k not in axes_b]
218+
newaxes_b = axes_b + notin
219+
N2 = 1
220+
for axis in axes_b:
221+
N2 *= bs[axis]
222+
newshape_b = (N2, -1)
223+
oldb = [bs[axis] for axis in notin]
224+
225+
at = a.transpose(newaxes_a).reshape(newshape_a)
226+
bt = b.transpose(newaxes_b).reshape(newshape_b)
227+
res = dot(at, bt)
228+
return res.reshape(olda + oldb)

circle.yml

+1
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ dependencies:
1515

1616
test:
1717
override:
18+
- python -c 'import numpy; print numpy.version.version'
1819
- make test

0 commit comments

Comments
 (0)