def restore_checkpoint(self, folder):
    """Restore this model's TensorFlow variables from a checkpoint in `folder`.

    Counterpart of `save_checkpoint`; the checkpoint path is built as
    folder/"Model", so this presumably expects files written with the same
    "Model" prefix — TODO confirm against `save_checkpoint`.
    """
    with self._graph.as_default():
        # A fresh Saver over the current graph restores values into the
        # session held by this instance.
        tf.train.Saver().restore(self._sess, os.path.join(folder, "Model"))
# API
577
577
@@ -1007,7 +1007,7 @@ class AutoBase:
1007
1007
def __init__ (self , name = None , data_info = None , pre_process_settings = None , nan_handler_settings = None ,
1008
1008
* args , ** kwargs ):
1009
1009
if name is None :
1010
- raise ValueError ("name should be provided when using AutoMixin " )
1010
+ raise ValueError ("name should be provided when using AutoBase " )
1011
1011
self ._name = name
1012
1012
1013
1013
self ._data_folder = None
@@ -1377,12 +1377,6 @@ def _load_data(self, data=None, numerical_idx=None, file_type="txt", names=("tra
1377
1377
1378
1378
return x , y , x_test , y_test
1379
1379
1380
- def _define_py_collections (self ):
1381
- self .py_collections += [
1382
- "pre_process_settings" , "nan_handler_settings" ,
1383
- "_pre_processors" , "_nan_handler" , "transform_dicts"
1384
- ]
1385
-
1386
1380
def get_transformed_data_from_file(self, file, file_type="txt", include_label=False):
    """Load a data file and return it transformed by the fitted pipeline.

    `file_type` is forwarded to `_get_data_from_file`; the second element of
    its result (presumably the labels — verify against that helper) is
    discarded here. `include_label` is forwarded to `_transform_data`.
    """
    features = self._get_data_from_file(file_type, 0, file)[0]
    return self._transform_data(features, "new", include_label=include_label)
@@ -1664,17 +1658,49 @@ def _get_score(mean, std, sign):
1664
1658
return mean - std
1665
1659
return mean + std
1666
1660
1661
+ @staticmethod
1662
+ def _extract_info (dtype , info ):
1663
+ if dtype == "choice" :
1664
+ return info [0 ][random .randint (0 , len (info [0 ]) - 1 )]
1665
+ if len (info ) == 2 :
1666
+ floor , ceiling = info
1667
+ distribution = "linear"
1668
+ else :
1669
+ floor , ceiling , distribution = info
1670
+ if ceiling <= floor :
1671
+ raise ValueError ("ceiling should be greater than floor" )
1672
+ if dtype == "int" :
1673
+ return random .randint (floor , ceiling )
1674
+ if dtype == "float" :
1675
+ linear_target = floor + random .random () * (ceiling - floor )
1676
+ distribution_error_msg = "distribution '{}' not supported in range_search" .format (distribution )
1677
+ if distribution == "linear" :
1678
+ return linear_target
1679
+ if distribution [:3 ] == "log" :
1680
+ sign , log = int (linear_target > 0 ), math .log (math .fabs (linear_target ))
1681
+ if distribution == "log" :
1682
+ return sign * math .exp (log )
1683
+ if distribution == "log2" :
1684
+ return sign * 2 ** log
1685
+ if distribution == "log10" :
1686
+ return sign * 10 ** log
1687
+ raise NotImplementedError (distribution_error_msg )
1688
+ raise NotImplementedError (distribution_error_msg )
1689
+ raise NotImplementedError ("dtype '{}' not supported in range_search" .format (dtype ))
1690
+
1667
1691
def _select_parameter (self , params ):
1668
1692
scores = []
1669
1693
sign = Metrics .sign_dict [self ._metric_name ]
1670
1694
for i , param in enumerate (params ):
1671
1695
mean , std = self .mean_record [i ], self .std_record [i ]
1672
- train_mean , cv_mean , test_mean = mean
1673
- train_std , cv_std , test_std = std
1674
- if test_mean is None or test_std is None :
1696
+ if len ( mean ) == 2 :
1697
+ train_mean , cv_mean = mean
1698
+ train_std , cv_std = std
1675
1699
weighted_mean = 0.2 * train_mean + 0.8 * cv_mean
1676
1700
weighted_std = 0.2 * train_std + 0.8 * cv_std
1677
1701
else :
1702
+ train_mean , cv_mean , test_mean = mean
1703
+ train_std , cv_std , test_std = std
1678
1704
weighted_mean = 0.1 * train_mean + 0.2 * cv_mean + 0.7 * test_mean
1679
1705
weighted_std = 0.1 * train_std + 0.2 * cv_std + 0.7 * test_std
1680
1706
scores .append (self ._get_score (weighted_mean , weighted_std , sign ))
@@ -1824,25 +1850,9 @@ def get_param_by_range(self, param):
1824
1850
if not isinstance (dtype , str ) and isinstance (dtype , collections .Iterable ):
1825
1851
local_param_list = []
1826
1852
for local_dtype , local_info in zip (dtype , info ):
1827
- if local_dtype == "choice" :
1828
- local_param_list .append (np .random .choice (local_info [0 ], 1 )[0 ])
1829
- continue
1830
- floor , ceiling = local_info
1831
- if local_dtype == "int" :
1832
- local_param_list .append (random .randint (floor , ceiling ))
1833
- elif dtype == "float" :
1834
- local_param_list .append (floor + random .random () * (ceiling - floor ))
1835
- else :
1836
- raise NotImplementedError ("dtype '{}' not supported in range_search" .format (dtype ))
1853
+ local_param_list .append (self ._extract_info (local_dtype , local_info ))
1837
1854
return local_param_list
1838
- if dtype == "choice" :
1839
- return np .random .choice (info [0 ], 1 )[0 ]
1840
- floor , ceiling = info
1841
- if dtype == "int" :
1842
- return random .randint (floor , ceiling )
1843
- if dtype == "float" :
1844
- return floor + random .random () * (ceiling - floor )
1845
- raise NotImplementedError ("dtype '{}' not supported in range_search" .format (dtype ))
1855
+ return self ._extract_info (dtype , info )
1846
1856
1847
1857
def range_search (self , n , grid_params , switch_to_best_params = True ,
1848
1858
k = 3 , data = None , cv_rate = 0.1 , test_rate = 0. , sample_weights = None , ** kwargs ):
0 commit comments