Skip to content

Commit 8df86f7

Browse files
committed
Optimized _transform_data method
1 parent 316a3bc commit 8df86f7

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

_Dist/NeuralNetworks/Base.py

+13 -7
Original file line number | Diff line number | Diff line change
@@ -1153,6 +1153,7 @@ def _transform_data(self, data, name, train_name="train",
11531153
"" if name == train_name or not self.reuse_mean_and_std else
11541154
" with {} data".format(train_name),
11551155
))
1156+
is_ndarray = isinstance(data, np.ndarray)
11561157
if refresh_redundant_info or self.whether_redundant is None:
11571158
self.whether_redundant = np.array([
11581159
True if local_dict is None else False
@@ -1196,9 +1197,12 @@ def _transform_data(self, data, name, train_name="train",
11961197
line[j] = local_dict["nan"]
11971198
else:
11981199
line[j] = local_dict[elem]
1199-
if whether_redundant is not None:
1200+
if not is_ndarray and whether_redundant is not None:
12001201
data[i] = [line[j] for j in valid_indices]
1201-
data = np.array(data, dtype=np.float32)
1202+
if is_ndarray and whether_redundant is not None:
1203+
data = data[..., valid_indices].astype(np.float32)
1204+
else:
1205+
data = np.array(data, dtype=np.float32)
12021206
if stage == 2 or stage == 3:
12031207
data = np.asarray(data, dtype=np.float32)
12041208
# Handle nan
@@ -1296,24 +1300,26 @@ def _load_data(self, data=None, numerical_idx=None, file_type="txt", names=("tra
12961300
n_train = None
12971301
else:
12981302
if data is None:
1303+
is_ndarray = False
12991304
data, test_rate = self._get_data_from_file(file_type, test_rate)
13001305
else:
1306+
is_ndarray = True
13011307
if not isinstance(data, tuple):
13021308
test_rate = 0
1303-
data = np.asarray(data, dtype=np.float32).tolist() # type: list
1309+
data = np.asarray(data, dtype=np.float32)
13041310
else:
13051311
data = tuple(
13061312
arr if isinstance(arr, list) else
1307-
np.asarray(arr, np.float32).tolist() for arr in data
1313+
np.asarray(arr, np.float32) for arr in data
13081314
)
13091315
if isinstance(data, tuple):
13101316
if shuffle:
1311-
random.shuffle(data[0])
1317+
np.random.shuffle(data[0]) if is_ndarray else random.shuffle(data[0])
13121318
n_train = len(data[0])
1313-
data = data[0] + data[1]
1319+
data = np.vstack(data) if is_ndarray else data[0] + data[1]
13141320
else:
13151321
if shuffle:
1316-
random.shuffle(data)
1322+
np.random.shuffle(data) if is_ndarray else random.shuffle(data)
13171323
n_train = int(len(data) * (1 - test_rate)) if test_rate > 0 else -1
13181324

13191325
if not os.path.isdir(data_info_folder):

0 commit comments

Comments
 (0)