@@ -121,16 +121,18 @@ def __repr__(self):
121
121
kwlist .append (("categories" , self .categories ))
122
122
repr_pretty_impl (p , self , [], kwlist )
123
123
124
- def __eq__ (self , other ):
125
- return self .__dict__ == other .__dict__
124
+ def __getstate__ (self ):
125
+ return {'version' : 0 , 'factor' : self .factor , 'type' : self .type ,
126
+ 'state' : self .state , 'num_columns' : self .num_columns ,
127
+ 'categories' : self .categories }
126
128
127
- def __hash__ (self ):
128
- if not self .categories :
129
- categories = 'NoCategories'
130
- else :
131
- categories = frozenset ( self . categories )
132
- return hash (( FactorInfo , str ( self .factor ), str ( self . type ),
133
- str ( self .state ), str ( self . num_columns ), categories ))
129
+ def __setstate__ (self , pickle ):
130
+ check_pickle_version ( pickle [ 'version' ], 0 , self .__class__ . __name__ )
131
+ self . factor = pickle [ 'factor' ]
132
+ self . type = pickle [ 'type' ]
133
+ self . state = pickle [ 'state' ]
134
+ self .num_columns = pickle [ 'num_columns' ]
135
+ self .categories = pickle [ ' categories' ]
134
136
135
137
136
138
def test_FactorInfo ():
@@ -245,10 +247,17 @@ def _repr_pretty_(self, p, cycle):
245
247
("contrast_matrices" , self .contrast_matrices ),
246
248
("num_columns" , self .num_columns )])
247
249
248
- def __eq__ (self , other ):
249
- return self .__dict__ == other .__dict__
250
+ def __getstate__ (self ):
251
+ return {'version' : 0 , 'factors' : self .factors ,
252
+ 'contrast_matrices' : self .contrast_matrices ,
253
+ 'num_columns' : self .num_columns }
254
+
255
+ def __setstate__ (self , pickle ):
256
+ check_pickle_version (pickle ['version' ], 0 , self .__class__ .__name__ )
257
+ self .factors = pickle ['factors' ]
258
+ self .contrast_matrices = pickle ['contrast_matrices' ]
259
+ self .num_columns = pickle ['num_columns' ]
250
260
251
- # __getstate__ = no_pickling
252
261
253
262
def test_SubtermInfo ():
254
263
cm = ContrastMatrix (np .ones ((2 , 2 )), ["[1]" , "[2]" ])
@@ -706,21 +715,19 @@ def from_array(cls, array_like, default_column_prefix="column"):
706
715
return DesignInfo (column_names )
707
716
708
717
def __getstate__ (self ):
709
- return (0 , self .column_name_indexes , self .factor_infos ,
710
- self .term_codings , self .term_slices , self .term_name_slices )
718
+ return {'version' : 0 , 'column_name_indexes' : self .column_name_indexes ,
719
+ 'factor_infos' : self .factor_infos ,
720
+ 'term_codings' : self .term_codings ,
721
+ 'term_slices' : self .term_slices ,
722
+ 'term_name_slices' : self .term_name_slices }
711
723
712
724
def __setstate__ (self , pickle ):
713
- (version , column_name_indexes , factor_infos , term_codings ,
714
- term_slices , term_name_slices ) = pickle
715
- check_pickle_version (version , 0 , self .__class__ .__name__ )
716
- self .column_name_indexes = column_name_indexes
717
- self .factor_infos = factor_infos
718
- self .term_codings = term_codings
719
- self .term_slices = term_slices
720
- self .term_name_slices = term_name_slices
721
-
722
- def __eq__ (self , other ):
723
- return self .__dict__ == other .__dict__
725
+ check_pickle_version (pickle ['version' ], 0 , self .__class__ .__name__ )
726
+ self .column_name_indexes = pickle ['column_name_indexes' ]
727
+ self .factor_infos = pickle ['factor_infos' ]
728
+ self .term_codings = pickle ['term_codings' ]
729
+ self .term_slices = pickle ['term_slices' ]
730
+ self .term_name_slices = pickle ['term_name_slices' ]
724
731
725
732
726
733
class _MockFactor (object ):
@@ -772,9 +779,12 @@ def test_DesignInfo():
772
779
773
780
# smoke test
774
781
repr (di )
775
- from six .moves import cPickle as pickle
776
782
777
- assert di == pickle .loads (pickle .dumps (di , pickle .HIGHEST_PROTOCOL ))
783
+ # Pickling check
784
+ from six .moves import cPickle as pickle
785
+ from patsy .util import assert_pickled_equals
786
+ di2 = pickle .loads (pickle .dumps (di , pickle .HIGHEST_PROTOCOL ))
787
+ assert_pickled_equals (di , di2 )
778
788
779
789
# One without term objects
780
790
di = DesignInfo (["a1" , "a2" , "a3" , "b" ])
@@ -795,7 +805,8 @@ def test_DesignInfo():
795
805
assert di .slice ("a3" ) == slice (2 , 3 )
796
806
assert di .slice ("b" ) == slice (3 , 4 )
797
807
798
- assert di == pickle .loads (pickle .dumps (di , pickle .HIGHEST_PROTOCOL ))
808
+ di2 = pickle .loads (pickle .dumps (di , pickle .HIGHEST_PROTOCOL ))
809
+ assert_pickled_equals (di , di2 )
799
810
800
811
# Check intercept handling in describe()
801
812
assert DesignInfo (["Intercept" , "a" , "b" ]).describe () == "1 + a + b"
0 commit comments