@@ -685,7 +685,7 @@ def var_names(self, eval_env=0):
685
685
else :
686
686
return {}
687
687
688
- def partial (self , columns , product = False ):
688
+ def partial (self , columns , product = False , eval_env = 0 ):
689
689
"""Returns a partial prediction array where only the variables in the
690
690
dict ``columns`` are tranformed per the :class:`DesignInfo`
691
691
transformations. The terms that are not influenced by ``columns``
@@ -703,6 +703,18 @@ def partial(self, columns, product=False):
703
703
:returns: A numpy array of the partial design matrix.
704
704
"""
705
705
from .highlevel import dmatrix
706
+ from types import ModuleType
707
+
708
+ if not eval_env :
709
+ from patsy .eval import EvalEnvironment
710
+ eval_env = EvalEnvironment .capture (eval_env , reference = 1 )
711
+
712
+ # We need to get rid of the non-callable items from the eval_env
713
+ namespaces = [{key : value } for ns in eval_env ._namespaces
714
+ for key , value in six .iteritems (ns )
715
+ if callable (value ) or isinstance (value , ModuleType )]
716
+ eval_env ._namespaces = namespaces
717
+
706
718
if product :
707
719
columns = _column_product (columns )
708
720
rows = None
@@ -712,7 +724,7 @@ def partial(self, columns, product=False):
712
724
rows = len (columns [col ])
713
725
parts = []
714
726
for term , subterm in six .iteritems (self .term_codings ):
715
- term_vars = term .var_names ()
727
+ term_vars = term .var_names (eval_env )
716
728
present = True
717
729
for term_var in term_vars :
718
730
if term_var not in columns :
@@ -1312,6 +1324,16 @@ def test_DesignInfo_partial():
1312
1324
assert_raises (ValueError , dm .design_info .partial , {'a' : ['a' , 'b' ],
1313
1325
'b' : [1 , 2 , 3 ]})
1314
1326
1327
+ def some_function (x ):
1328
+ return np .where (x > 2 , 1 , 2 )
1329
+
1330
+ dm = dmatrix ('1 + some_function(c)' )
1331
+ x = np .array ([[0 , 2 ],
1332
+ [0 , 2 ],
1333
+ [0 , 1 ]])
1334
+ y = dm .design_info .partial ({'c' : np .array ([1 , 2 , 3 ])})
1335
+ assert_allclose (x , y )
1336
+
1315
1337
1316
1338
def _column_product (columns ):
1317
1339
from itertools import product
0 commit comments