@@ -1153,6 +1153,7 @@ def _transform_data(self, data, name, train_name="train",
1153
1153
"" if name == train_name or not self .reuse_mean_and_std else
1154
1154
" with {} data" .format (train_name ),
1155
1155
))
1156
+ is_ndarray = isinstance (data , np .ndarray )
1156
1157
if refresh_redundant_info or self .whether_redundant is None :
1157
1158
self .whether_redundant = np .array ([
1158
1159
True if local_dict is None else False
@@ -1196,9 +1197,12 @@ def _transform_data(self, data, name, train_name="train",
1196
1197
line [j ] = local_dict ["nan" ]
1197
1198
else :
1198
1199
line [j ] = local_dict [elem ]
1199
- if whether_redundant is not None :
1200
+ if not is_ndarray and whether_redundant is not None :
1200
1201
data [i ] = [line [j ] for j in valid_indices ]
1201
- data = np .array (data , dtype = np .float32 )
1202
+ if is_ndarray and whether_redundant is not None :
1203
+ data = data [..., valid_indices ].astype (np .float32 )
1204
+ else :
1205
+ data = np .array (data , dtype = np .float32 )
1202
1206
if stage == 2 or stage == 3 :
1203
1207
data = np .asarray (data , dtype = np .float32 )
1204
1208
# Handle nan
@@ -1296,24 +1300,26 @@ def _load_data(self, data=None, numerical_idx=None, file_type="txt", names=("tra
1296
1300
n_train = None
1297
1301
else :
1298
1302
if data is None :
1303
+ is_ndarray = False
1299
1304
data , test_rate = self ._get_data_from_file (file_type , test_rate )
1300
1305
else :
1306
+ is_ndarray = True
1301
1307
if not isinstance (data , tuple ):
1302
1308
test_rate = 0
1303
- data = np .asarray (data , dtype = np .float32 ). tolist () # type: list
1309
+ data = np .asarray (data , dtype = np .float32 )
1304
1310
else :
1305
1311
data = tuple (
1306
1312
arr if isinstance (arr , list ) else
1307
- np .asarray (arr , np .float32 ). tolist () for arr in data
1313
+ np .asarray (arr , np .float32 ) for arr in data
1308
1314
)
1309
1315
if isinstance (data , tuple ):
1310
1316
if shuffle :
1311
- random .shuffle (data [0 ])
1317
+ np . random . shuffle ( data [ 0 ]) if is_ndarray else random .shuffle (data [0 ])
1312
1318
n_train = len (data [0 ])
1313
- data = data [0 ] + data [1 ]
1319
+ data = np . vstack ( data ) if is_ndarray else data [0 ] + data [1 ]
1314
1320
else :
1315
1321
if shuffle :
1316
- random .shuffle (data )
1322
+ np . random . shuffle ( data ) if is_ndarray else random .shuffle (data )
1317
1323
n_train = int (len (data ) * (1 - test_rate )) if test_rate > 0 else - 1
1318
1324
1319
1325
if not os .path .isdir (data_info_folder ):
0 commit comments