Skip to content

Commit 2d11bfe

Browse files
committed
Update _Dist/NeuralNetworks
1 parent 61ffdfb commit 2d11bfe

File tree

2 files changed

+6
-11
lines changed

2 files changed

+6
-11
lines changed

_Dist/NeuralNetworks/Base.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -133,13 +133,10 @@ def gen_batch(self, n_batch, re_shuffle=True):
133133
if next_cursor >= self.n_valid:
134134
next_cursor = self.n_valid
135135
end = True
136-
rs, w = self._get_data(indices[self._batch_cursor:next_cursor])
137-
if end:
138-
self._batch_cursor = -1
139-
else:
140-
self._batch_cursor = next_cursor
136+
data, w = self._get_data(indices[self._batch_cursor:next_cursor])
137+
self._batch_cursor = -1 if end else next_cursor
141138
logger.debug("Done")
142-
return rs, w
139+
return data, w
143140

144141
def gen_random_subset(self, n):
145142
n = min(n, self.n_valid)

_Dist/NeuralNetworks/NNUtil.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,8 @@ def get_feature_info(data, numerical_idx, is_regression, logger=None):
313313
else:
314314
shrink_features = data_t
315315
feature_sets = [
316-
set() if idx is None or idx else set(shrink_feat)
317-
for idx, shrink_feat in zip(numerical_idx, shrink_features)
316+
set() if idx is None or idx else set(shrink_feature)
317+
for idx, shrink_feature in zip(numerical_idx, shrink_features)
318318
]
319319
n_features = [len(feature_set) for feature_set in feature_sets]
320320
all_num_idx = [
@@ -349,9 +349,7 @@ def get_feature_info(data, numerical_idx, is_regression, logger=None):
349349
all_num_idx[i] = numerical_idx[i] = None
350350
elif numerical_idx[i]:
351351
shrink_feature = np.asarray(shrink_feature, np.float32)
352-
if np.isnan(shrink_feature[-1]):
353-
shrink_feature = shrink_feature[:-1]
354-
if np.max(shrink_feature) < 2 ** 30:
352+
if np.max(shrink_feature[~np.isnan(shrink_feature)]) < 2 ** 30:
355353
if np.allclose(shrink_feature, np.array(shrink_feature, np.int32)):
356354
if Toolbox.all_unique(shrink_feature):
357355
Toolbox.warn_all_unique(i, logger)

0 commit comments

Comments
 (0)