Skip to content

Commit

Permalink
move dataset test to basic
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Nov 29, 2024
1 parent 50dda90 commit 38c6786
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 20 deletions.
19 changes: 19 additions & 0 deletions tests/python_package_test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,3 +964,22 @@ def test_no_copy_in_dataset_from_numpy_2d(rng, order, dtype):
else:
# makes a copy
assert not np.shares_memory(X, X1d)


def test_equal_datasets_from_row_major_and_col_major_data(tmp_path):
# row-major dataset
X_row, y = make_blobs(n_samples=1_000, n_features=1, centers=2)
assert X_row.flags["C_CONTIGUOUS"]
ds_row = lgb.Dataset(X_row, y)
ds_row_path = tmp_path / "ds_row.txt"
ds_row._dump_text(ds_row_path)

# col-major dataset
X_col = np.asfortranarray(X_row)
assert X_col.flags["F_CONTIGUOUS"]
ds_col = lgb.Dataset(X_col, y)
ds_col_path = tmp_path / "ds_col.txt"
ds_col._dump_text(ds_col_path)

# check datasets are equal
assert filecmp.cmp(ds_row_path, ds_col_path)
20 changes: 0 additions & 20 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# coding: utf-8
import copy
import filecmp
import itertools
import json
import math
Expand Down Expand Up @@ -4612,22 +4611,3 @@ def test_bagging_by_query_in_lambdarank():
ndcg_score_no_bagging_by_query = gbm_no_bagging_by_query.best_score["valid_0"]["ndcg@5"]
assert ndcg_score_bagging_by_query >= ndcg_score - 0.1
assert ndcg_score_no_bagging_by_query >= ndcg_score - 0.1


def test_equal_datasets_from_row_major_and_col_major_data(tmp_path):
# row-major dataset
X_row, y = make_synthetic_regression()
assert X_row.flags["C_CONTIGUOUS"]
ds_row = lgb.Dataset(X_row, y)
ds_row_path = tmp_path / "ds_row.txt"
ds_row._dump_text(ds_row_path)

# col-major dataset
X_col = np.asfortranarray(X_row)
assert X_col.flags["F_CONTIGUOUS"]
ds_col = lgb.Dataset(X_col, y)
ds_col_path = tmp_path / "ds_col.txt"
ds_col._dump_text(ds_col_path)

# check datasets are equal
assert filecmp.cmp(ds_row_path, ds_col_path)

0 comments on commit 38c6786

Please sign in to comment.