Skip to content

Commit

Permalink
Fix relative imports after pip installation (mattloper#11)
Browse files Browse the repository at this point in the history
These namespace relative imports work in chumpy tests, but were causing import errors like this one:

```
  File "/home/circleci/repo/blmath/geometry/transform/test_composite.py", line 172, in test_forward_reverse_equivalence
    transform.rotate(np.array([1., 2., 3.]))
  File "/home/circleci/repo/blmath/geometry/transform/composite.py", line 237, in rotate
    from blmath.geometry.transform.rodrigues import as_rotation_matrix
  File "/home/circleci/repo/blmath/geometry/transform/rodrigues.py", line 2, in <module>
    import chumpy as ch
  File "/home/circleci/repo/venv/lib/python2.7/site-packages/chumpy/__init__.py", line 1, in <module>
    from .ch import *
  File "/home/circleci/repo/venv/lib/python2.7/site-packages/chumpy/ch.py", line 1310, in <module>
    from . import ch_ops
  File "/home/circleci/repo/venv/lib/python2.7/site-packages/chumpy/ch_ops.py", line 47, in <module>
    from . import ch
ImportError: cannot import name ch
```
# Conflicts:
#	Makefile
  • Loading branch information
paulmelnikow committed Sep 9, 2018
1 parent db6eaf8 commit b3a5953
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 80 deletions.
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ tidy:

test: clean qtest
qtest: all
python -m unittest
# For some reason the import changes for Python 3 caused the Python 2 test
# loader to give up without loading any tests. So we discover them ourselves.
# python -m unittest
find chumpy -name 'test_*.py' | sed -e 's/\.py$$//' -e 's/\//./' | xargs python -m unittest

coverage: clean qcov
qcov: all
Expand Down
90 changes: 45 additions & 45 deletions chumpy/ch.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@
import copy as external_copy
from functools import wraps
from scipy.sparse.linalg.interface import LinearOperator
from . import utils
from .utils import row, col
from .utils import row, col, timer, convert_inputs_to_sparse_if_necessary
import collections
from copy import deepcopy
from .utils import timer
from functools import reduce



# Turn this on if you want the profiler injected
DEBUG = False
# Turn this on to make optimizations very chatty for debugging
Expand Down Expand Up @@ -191,10 +190,10 @@ def sid(self):


def reshape(self, *args):
return reordering.reshape(a=self, newshape=args if len(args)>1 else args[0])
return reshape(a=self, newshape=args if len(args)>1 else args[0])

def ravel(self):
return reordering.reshape(a=self, newshape=(-1))
return reshape(a=self, newshape=(-1))

def __hash__(self):
return id(self)
Expand Down Expand Up @@ -294,7 +293,7 @@ def _compute_dr_wrt_sliced(self, wrt):
# What allows slicing.
if True:
inner = wrt
while issubclass(inner.__class__, reordering.Permute):
while issubclass(inner.__class__, Permute):
inner = inner.a
if inner is self:
return None
Expand Down Expand Up @@ -340,7 +339,7 @@ def add_dterm(self, dterm_name, dterm):
setattr(self, dterm_name, dterm)

def copy(self):
return ch_ops.copy(self)
return copy(self)

def __getstate__(self):
# Have to get rid of WeakKeyDictionaries for serialization
Expand Down Expand Up @@ -502,7 +501,7 @@ def __getitem__(self, key):
tmp = np.arange(np.prod(shape)).reshape(shape).__getitem__(key)
idxs = tmp.ravel()
newshape = tmp.shape
return reordering.Select(a=self, idxs=idxs, preferred_shape=newshape)
return Select(a=self, idxs=idxs, preferred_shape=newshape)

def __setitem__(self, key, value, itr=None):

Expand All @@ -527,7 +526,7 @@ def __setitem__(self, key, value, itr=None):
else:
inner = self
while not inner.is_ch_baseclass():
if issubclass(inner.__class__, reordering.Permute):
if issubclass(inner.__class__, Permute):
inner = inner.a
else:
raise Exception("Can't set array that is function of arrays.")
Expand Down Expand Up @@ -564,19 +563,19 @@ def on_changed(self, terms):

@property
def T(self):
return reordering.transpose(self)
return transpose(self)

def transpose(self, *axes):
return reordering.transpose(self, *axes)
return transpose(self, *axes)

def squeeze(self, axis=None):
return ch_ops.squeeze(self, axis)
return squeeze(self, axis)

def mean(self, axis=None):
return ch_ops.mean(self, axis=axis)
return mean(self, axis=axis)

def sum(self, axis=None):
return ch_ops.sum(self, axis=axis)
return sum(self, axis=axis)

def _call_on_changed(self):

Expand Down Expand Up @@ -630,7 +629,7 @@ def _superdot(self, lhs, rhs, profiler=None):

# TODO: Figure out how/whether to do this.
tm_maybe_sparse = timer()
lhs, rhs = utils.convert_inputs_to_sparse_if_necessary(lhs, rhs)
lhs, rhs = convert_inputs_to_sparse_if_necessary(lhs, rhs)
if tm_maybe_sparse() > 0.1:
pif('convert_inputs_to_sparse_if_necessary in {}sec'.format(tm_maybe_sparse()))

Expand Down Expand Up @@ -894,7 +893,7 @@ def string_for(self, my_name):
color = color_mapping[dtval._status] if hasattr(dtval, '_status') else 'grey'
if dtval == current_node:
color = 'blue'
if isinstance(dtval, reordering.Concatenate) and len(dtval.dr_cached) > 0:
if isinstance(dtval, Concatenate) and len(dtval.dr_cached) > 0:
s = 'dr_cached\n'
for k, v in dtval.dr_cached.items():
if v is not None:
Expand Down Expand Up @@ -1111,59 +1110,59 @@ def tree_iterator(self, visited=None, path=None):
yield node

def floor(self):
return ch_ops.floor(self)
return floor(self)

def ceil(self):
return ch_ops.ceil(self)
return ceil(self)

def dot(self, other):
return ch_ops.dot(self, other)
return dot(self, other)

def cumsum(self, axis=None):
return ch_ops.cumsum(a=self, axis=axis)
return cumsum(a=self, axis=axis)

def min(self, axis=None):
return ch_ops.amin(a=self, axis=axis)
return amin(a=self, axis=axis)

def max(self, axis=None):
return ch_ops.amax(a=self, axis=axis)
return amax(a=self, axis=axis)

########################################################
# Operator overloads

def __pos__(self): return self
def __neg__(self): return ch_ops.negative(self)
def __neg__(self): return negative(self)

def __add__ (self, other): return ch_ops.add(a=self, b=other)
def __radd__(self, other): return ch_ops.add(a=other, b=self)
def __add__ (self, other): return add(a=self, b=other)
def __radd__(self, other): return add(a=other, b=self)

def __sub__ (self, other): return ch_ops.subtract(a=self, b=other)
def __rsub__(self, other): return ch_ops.subtract(a=other, b=self)
def __sub__ (self, other): return subtract(a=self, b=other)
def __rsub__(self, other): return subtract(a=other, b=self)

def __mul__ (self, other): return ch_ops.multiply(a=self, b=other)
def __rmul__(self, other): return ch_ops.multiply(a=other, b=self)
def __mul__ (self, other): return multiply(a=self, b=other)
def __rmul__(self, other): return multiply(a=other, b=self)

def __div__ (self, other): return ch_ops.divide(x1=self, x2=other)
def __truediv__ (self, other): return ch_ops.divide(x1=self, x2=other)
def __rdiv__(self, other): return ch_ops.divide(x1=other, x2=self)
def __div__ (self, other): return divide(x1=self, x2=other)
def __truediv__ (self, other): return divide(x1=self, x2=other)
def __rdiv__(self, other): return divide(x1=other, x2=self)

def __pow__ (self, other): return ch_ops.power(x=self, pow=other)
def __rpow__(self, other): return ch_ops.power(x=other, pow=self)
def __pow__ (self, other): return power(x=self, pow=other)
def __rpow__(self, other): return power(x=other, pow=self)

def __rand__(self, other): return self.__and__(other)

def __abs__ (self): return ch_ops.abs(self)
def __abs__ (self): return abs(self)

def __gt__(self, other): return ch_ops.greater(self, other)
def __ge__(self, other): return ch_ops.greater_equal(self, other)
def __gt__(self, other): return greater(self, other)
def __ge__(self, other): return greater_equal(self, other)

def __lt__(self, other): return ch_ops.less(self, other)
def __le__(self, other): return ch_ops.less_equal(self, other)
def __lt__(self, other): return less(self, other)
def __le__(self, other): return less_equal(self, other)

def __ne__(self, other): return ch_ops.not_equal(self, other)
def __ne__(self, other): return not_equal(self, other)

# not added yet because of weak key dict conflicts
#def __eq__(self, other): return ch_ops.equal(self, other)
#def __eq__(self, other): return equal(self, other)


Ch._reserved_kw = set(Ch.__dict__.keys())
Expand Down Expand Up @@ -1307,13 +1306,14 @@ def compute_r(self):
def compute_dr_wrt(self, wrt):
return self._result.dr_wrt(wrt)

from . import ch_ops
from .ch_ops import *
__all__ += ch_ops.__all__
from .ch_ops import __all__ as all_ch_ops
__all__ += all_ch_ops

from . import reordering
from .reordering import *
__all__ += reordering.__all__
from .reordering import Permute
from .reordering import __all__ as all_reordering
__all__ += all_reordering


from . import linalg
Expand Down
Loading

0 comments on commit b3a5953

Please sign in to comment.