Skip to content

Commit a7b89de

Browse files
committed
Update _Dist/NeuralNetworks
1 parent 9c6d822 commit a7b89de

File tree

3 files changed

+45
-31
lines changed

3 files changed

+45
-31
lines changed

_Dist/NeuralNetworks/Base.py

+39-29
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ def save_checkpoint(self, folder):
571571

572572
def restore_checkpoint(self, folder):
    """Restore session variables from the checkpoint stored in *folder*.

    Loads the file named "Model" inside *folder* into self._sess, using a
    Saver built on this model's own graph.
    """
    checkpoint_path = os.path.join(folder, "Model")
    with self._graph.as_default():
        saver = tf.train.Saver()
        saver.restore(self._sess, checkpoint_path)
575575

576576
# API
577577

@@ -1007,7 +1007,7 @@ class AutoBase:
10071007
def __init__(self, name=None, data_info=None, pre_process_settings=None, nan_handler_settings=None,
10081008
*args, **kwargs):
10091009
if name is None:
1010-
raise ValueError("name should be provided when using AutoMixin")
1010+
raise ValueError("name should be provided when using AutoBase")
10111011
self._name = name
10121012

10131013
self._data_folder = None
@@ -1377,12 +1377,6 @@ def _load_data(self, data=None, numerical_idx=None, file_type="txt", names=("tra
13771377

13781378
return x, y, x_test, y_test
13791379

1380-
def _define_py_collections(self):
1381-
self.py_collections += [
1382-
"pre_process_settings", "nan_handler_settings",
1383-
"_pre_processors", "_nan_handler", "transform_dicts"
1384-
]
1385-
13861380
def get_transformed_data_from_file(self, file, file_type="txt", include_label=False):
13871381
x, _ = self._get_data_from_file(file_type, 0, file)
13881382
return self._transform_data(x, "new", include_label=include_label)
@@ -1664,17 +1658,49 @@ def _get_score(mean, std, sign):
16641658
return mean - std
16651659
return mean + std
16661660

1661+
@staticmethod
1662+
def _extract_info(dtype, info):
1663+
if dtype == "choice":
1664+
return info[0][random.randint(0, len(info[0]) - 1)]
1665+
if len(info) == 2:
1666+
floor, ceiling = info
1667+
distribution = "linear"
1668+
else:
1669+
floor, ceiling, distribution = info
1670+
if ceiling <= floor:
1671+
raise ValueError("ceiling should be greater than floor")
1672+
if dtype == "int":
1673+
return random.randint(floor, ceiling)
1674+
if dtype == "float":
1675+
linear_target = floor + random.random() * (ceiling - floor)
1676+
distribution_error_msg = "distribution '{}' not supported in range_search".format(distribution)
1677+
if distribution == "linear":
1678+
return linear_target
1679+
if distribution[:3] == "log":
1680+
sign, log = int(linear_target > 0), math.log(math.fabs(linear_target))
1681+
if distribution == "log":
1682+
return sign * math.exp(log)
1683+
if distribution == "log2":
1684+
return sign * 2 ** log
1685+
if distribution == "log10":
1686+
return sign * 10 ** log
1687+
raise NotImplementedError(distribution_error_msg)
1688+
raise NotImplementedError(distribution_error_msg)
1689+
raise NotImplementedError("dtype '{}' not supported in range_search".format(dtype))
1690+
16671691
def _select_parameter(self, params):
16681692
scores = []
16691693
sign = Metrics.sign_dict[self._metric_name]
16701694
for i, param in enumerate(params):
16711695
mean, std = self.mean_record[i], self.std_record[i]
1672-
train_mean, cv_mean, test_mean = mean
1673-
train_std, cv_std, test_std = std
1674-
if test_mean is None or test_std is None:
1696+
if len(mean) == 2:
1697+
train_mean, cv_mean = mean
1698+
train_std, cv_std = std
16751699
weighted_mean = 0.2 * train_mean + 0.8 * cv_mean
16761700
weighted_std = 0.2 * train_std + 0.8 * cv_std
16771701
else:
1702+
train_mean, cv_mean, test_mean = mean
1703+
train_std, cv_std, test_std = std
16781704
weighted_mean = 0.1 * train_mean + 0.2 * cv_mean + 0.7 * test_mean
16791705
weighted_std = 0.1 * train_std + 0.2 * cv_std + 0.7 * test_std
16801706
scores.append(self._get_score(weighted_mean, weighted_std, sign))
@@ -1824,25 +1850,9 @@ def get_param_by_range(self, param):
18241850
if not isinstance(dtype, str) and isinstance(dtype, collections.Iterable):
18251851
local_param_list = []
18261852
for local_dtype, local_info in zip(dtype, info):
1827-
if local_dtype == "choice":
1828-
local_param_list.append(np.random.choice(local_info[0], 1)[0])
1829-
continue
1830-
floor, ceiling = local_info
1831-
if local_dtype == "int":
1832-
local_param_list.append(random.randint(floor, ceiling))
1833-
elif dtype == "float":
1834-
local_param_list.append(floor + random.random() * (ceiling - floor))
1835-
else:
1836-
raise NotImplementedError("dtype '{}' not supported in range_search".format(dtype))
1853+
local_param_list.append(self._extract_info(local_dtype, local_info))
18371854
return local_param_list
1838-
if dtype == "choice":
1839-
return np.random.choice(info[0], 1)[0]
1840-
floor, ceiling = info
1841-
if dtype == "int":
1842-
return random.randint(floor, ceiling)
1843-
if dtype == "float":
1844-
return floor + random.random() * (ceiling - floor)
1845-
raise NotImplementedError("dtype '{}' not supported in range_search".format(dtype))
1855+
return self._extract_info(dtype, info)
18461856

18471857
def range_search(self, n, grid_params, switch_to_best_params=True,
18481858
k=3, data=None, cv_rate=0.1, test_rate=0., sample_weights=None, **kwargs):

_Dist/NeuralNetworks/e_AdvancedNN/NN.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def init_model_structure_settings(self):
9797
def _get_embedding(self, i, n):
    """Create a trainable embedding variable for the i-th categorical feature
    and return the lookup of that feature's ids into it.

    n is the feature's cardinality; when self.embedding_size == "log" the
    embedding width is ceil(log2(n)) + 1, otherwise self.embedding_size is
    used directly.
    """
    embedding_size = math.ceil(math.log2(n)) + 1 if self.embedding_size == "log" else self.embedding_size
    # NOTE(review): the table has shape [1, embedding_size], so the lookup
    # below is only valid if the ids in self._categorical_xs[i] are all 0;
    # a per-category table would need shape [n, embedding_size]. This commit
    # changed n -> 1 deliberately — confirm the intent before relying on it.
    embedding = tf.Variable(tf.truncated_normal(
        [1, embedding_size], mean=0, stddev=0.02
    ), name="Embedding{}".format(i))
    return tf.nn.embedding_lookup(embedding, self._categorical_xs[i], name="Embedded_X{}".format(i))
103103

_Dist/NeuralNetworks/g_DistNN/NN.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
sys.path.append(root_path)
66

77
from _Dist.NeuralNetworks.Base import DistMixin
8-
from _Dist.NeuralNetworks.f_AutoNN.NN import AutoAdvanced
8+
from _Dist.NeuralNetworks.f_AutoNN.NN import AutoBasic, AutoAdvanced
9+
10+
11+
class DistBasic(AutoBasic, DistMixin):
    """AutoBasic combined with DistMixin; adds no behavior of its own."""
    pass
913

1014

1115
class DistAdvanced(AutoAdvanced, DistMixin):

0 commit comments

Comments
 (0)