Skip to content

Commit 19ad339

Browse files
committed
Added logic to handle modules and user-defined functions
1 parent 460a6f9 commit 19ad339

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

patsy/design_info.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ def var_names(self, eval_env=0):
685685
else:
686686
return {}
687687

688-
def partial(self, columns, product=False):
688+
def partial(self, columns, product=False, eval_env=0):
689689
"""Returns a partial prediction array where only the variables in the
690690
dict ``columns`` are tranformed per the :class:`DesignInfo`
691691
transformations. The terms that are not influenced by ``columns``
@@ -703,6 +703,18 @@ def partial(self, columns, product=False):
703703
:returns: A numpy array of the partial design matrix.
704704
"""
705705
from .highlevel import dmatrix
706+
from types import ModuleType
707+
708+
if not eval_env:
709+
from patsy.eval import EvalEnvironment
710+
eval_env = EvalEnvironment.capture(eval_env, reference=1)
711+
712+
# We need to get rid of the non-callable items from the eval_env
713+
namespaces = [{key: value} for ns in eval_env._namespaces
714+
for key, value in six.iteritems(ns)
715+
if callable(value) or isinstance(value, ModuleType)]
716+
eval_env._namespaces = namespaces
717+
706718
if product:
707719
columns = _column_product(columns)
708720
rows = None
@@ -712,7 +724,7 @@ def partial(self, columns, product=False):
712724
rows = len(columns[col])
713725
parts = []
714726
for term, subterm in six.iteritems(self.term_codings):
715-
term_vars = term.var_names()
727+
term_vars = term.var_names(eval_env)
716728
present = True
717729
for term_var in term_vars:
718730
if term_var not in columns:
@@ -1312,6 +1324,16 @@ def test_DesignInfo_partial():
13121324
assert_raises(ValueError, dm.design_info.partial, {'a': ['a', 'b'],
13131325
'b': [1, 2, 3]})
13141326

1327+
def some_function(x):
1328+
return np.where(x > 2, 1, 2)
1329+
1330+
dm = dmatrix('1 + some_function(c)')
1331+
x = np.array([[0, 2],
1332+
[0, 2],
1333+
[0, 1]])
1334+
y = dm.design_info.partial({'c': np.array([1, 2, 3])})
1335+
assert_allclose(x, y)
1336+
13151337

13161338
def _column_product(columns):
13171339
from itertools import product

0 commit comments

Comments
 (0)