Skip to content

Commit 8e4cd1b

Browse files
committed
Update TestOpenML.py
1 parent 664be24 commit 8e4cd1b

File tree

1 file changed

+5
-12
lines changed

1 file changed

+5
-12
lines changed

_Dist/NeuralNetworks/_Tests/DistNN/TestOpenML.py

+5-12
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,6 @@
2323
]
2424

2525

26-
def swap(arr, i1, i2):
27-
arr[..., i1], arr[..., i2] = arr[..., i2], arr[..., i1].copy()
28-
29-
3026
def download_data():
3127
data_folder = "_Data"
3228
idx_folder = os.path.join(data_folder, "idx")
@@ -41,15 +37,12 @@ def download_data():
4137
if os.path.isfile(data_file) and os.path.isfile(idx_file):
4238
continue
4339
dataset = datasets.get_dataset(idx)
44-
data, categorical_idx, names = dataset.get_data(
45-
return_categorical_indicator=True,
46-
return_attribute_names=True
47-
)
48-
data = data.toarray() if not isinstance(data, np.ndarray) else data
49-
target_idx = names.index(dataset.default_target_attribute)
40+
x, y, categorical_idx, names = dataset.get_data(
41+
target=dataset.default_target_attribute, dataset_format="array")
42+
categorical_idx.append(False)
43+
to_array = lambda arr: arr.toarray() if not isinstance(arr, np.ndarray) else arr
44+
data = np.hstack(list(map(to_array, [x, y.reshape([-1, 1])])))
5045
numerical_idx = ~np.array(categorical_idx)
51-
swap(numerical_idx, target_idx, -1)
52-
swap(data, target_idx, -1)
5346
with open(data_file, "w") as file:
5447
file.write("\n".join([" ".join(map(lambda n: str(n), line)) for line in data]))
5548
np.save(idx_file, numerical_idx)

0 commit comments

Comments
 (0)