Skip to content

Commit 825c07e

Browse files
committed
Add unit test for data validation
1 parent 1de976a commit 825c07e

22 files changed

+1074
-104
lines changed

tests/unit_tests/data_validation/test_ClassImbalance.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
import pandas as pd
33
import validmind as vm
4+
from validmind.errors import SkipTestError
45
from validmind.tests.data_validation.ClassImbalance import ClassImbalance
56
from plotly.graph_objs import Figure
67

@@ -79,5 +80,5 @@ def test_missing_target(self):
7980
__log=False,
8081
)
8182

82-
with self.assertRaises(Exception):
83+
with self.assertRaises(SkipTestError):
8384
ClassImbalance(dataset_no_target)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import unittest
2+
import pandas as pd
3+
import validmind as vm
4+
import matplotlib.pyplot as plt
5+
from validmind.tests.data_validation.RollingStatsPlot import RollingStatsPlot
6+
7+
8+
class TestRollingStatsPlot(unittest.TestCase):
    """Unit tests for the ``RollingStatsPlot`` data validation test."""

    def setUp(self):
        """Build one dataset with a daily datetime index and one without."""
        # Registered up front so matplotlib figures are closed after every
        # test, including when an assertion fails before the end of the test.
        self.addCleanup(plt.close, "all")

        # Create a sample time series dataset
        dates = pd.date_range(start="2023-01-01", periods=100, freq="D")
        df = pd.DataFrame(
            {"A": range(100), "B": [i * 2 for i in range(100)]}, index=dates
        )

        self.vm_dataset = vm.init_dataset(
            input_id="test_dataset", dataset=df, feature_columns=["A", "B"], __log=False
        )

        # Create a dataset without datetime index
        df_no_datetime = pd.DataFrame(
            {"A": range(100), "B": [i * 2 for i in range(100)]}
        )

        self.vm_dataset_no_datetime = vm.init_dataset(
            input_id="test_dataset_no_datetime",
            dataset=df_no_datetime,
            feature_columns=["A", "B"],
            __log=False,
        )

    def test_rolling_stats_plot(self):
        """A datetime-indexed dataset yields one matplotlib figure per feature."""
        figures = RollingStatsPlot(self.vm_dataset, window_size=10)

        # Check that we get the correct number of figures (one per feature)
        self.assertEqual(len(figures), 2)

        # Check that outputs are matplotlib figures
        for fig in figures:
            self.assertIsInstance(fig, plt.Figure)

    def test_no_datetime_index(self):
        """A dataset without a datetime index is rejected with a datetime error."""
        # NOTE(review): the concrete exception class raised by RollingStatsPlot
        # is not visible here, so the broad ``Exception`` is kept and the
        # message check below does the real work - narrow once confirmed.
        with self.assertRaises(Exception) as context:
            RollingStatsPlot(self.vm_dataset_no_datetime)

        # Verify error message mentions datetime requirement
        self.assertIn("datetime", str(context.exception).lower())
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import unittest
2+
import pandas as pd
3+
import numpy as np
4+
import validmind as vm
5+
import plotly.graph_objects as go
6+
from validmind.tests.data_validation.SeasonalDecompose import SeasonalDecompose
7+
from validmind.errors import SkipTestError
8+
9+
10+
class TestSeasonalDecompose(unittest.TestCase):
    """Unit tests for the ``SeasonalDecompose`` data validation test."""

    def setUp(self):
        """Build a seasonal daily series and a copy with leading NaN values."""
        # A locally seeded generator makes the noise identical on every run,
        # so a failure is reproducible, and leaves NumPy's global random
        # state untouched for other tests.
        rng = np.random.default_rng(42)

        # Create a sample time series dataset with seasonal pattern
        dates = pd.date_range(start="2023-01-01", periods=100, freq="D")
        seasonal_pattern = np.sin(np.linspace(0, 4 * np.pi, 100))  # 2 complete cycles
        trend = np.linspace(0, 2, 100)  # upward trend
        noise = rng.normal(0, 0.1, 100)

        df = pd.DataFrame(
            {
                "feature1": seasonal_pattern + trend + noise,
                "feature2": seasonal_pattern * 2 + trend + noise,
            },
            index=dates,
        )

        self.vm_dataset = vm.init_dataset(
            input_id="test_dataset",
            dataset=df,
            feature_columns=["feature1", "feature2"],
            __log=False,
        )

        # Create dataset with non-finite values: the first ten rows of the
        # first column are blanked out.
        df_with_nan = df.copy()
        df_with_nan.iloc[0:10, 0] = np.nan
        self.vm_dataset_with_nan = vm.init_dataset(
            input_id="test_dataset_with_nan",
            dataset=df_with_nan,
            feature_columns=["feature1", "feature2"],
            __log=False,
        )

    def test_seasonal_decompose(self):
        """The default call returns a tuple with one plotly figure per feature."""
        figures = SeasonalDecompose(self.vm_dataset)

        # Check that we get the correct number of figures (one per feature)
        self.assertIsInstance(figures, tuple)
        self.assertEqual(len(figures), 2)

        # Check that outputs are plotly figures with correct subplots
        for fig in figures:
            self.assertIsInstance(fig, go.Figure)
            # Should have 6 subplots: Observed, Trend, Seasonal, Residuals,
            # Histogram, and Q-Q plot
            self.assertEqual(len(fig.data), 7)  # 6 plots + 1 QQ line

    def test_seasonal_decompose_with_nan(self):
        """NaN values in a feature do not stop both figures being produced."""
        # Should still work with NaN values
        figures = SeasonalDecompose(self.vm_dataset_with_nan)
        self.assertEqual(len(figures), 2)

    def test_seasonal_decompose_models(self):
        """The additive model succeeds; the multiplicative one raises ValueError."""
        # Test additive model (should work with any data)
        figures_add = SeasonalDecompose(self.vm_dataset, seasonal_model="additive")
        self.assertEqual(len(figures_add), 2)

        # Test multiplicative model (should raise ValueError for data with zero/negative values)
        with self.assertRaises(ValueError) as context:
            SeasonalDecompose(self.vm_dataset, seasonal_model="multiplicative")

        # Verify the error message
        self.assertIn(
            "Multiplicative seasonality is not appropriate for zero and negative values",
            str(context.exception),
        )
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import unittest
2+
import pandas as pd
3+
import numpy as np
4+
import validmind as vm
5+
from validmind.tests.data_validation.Skewness import Skewness
6+
7+
8+
class TestSkewness(unittest.TestCase):
    """Unit tests for the ``Skewness`` data validation test."""

    def setUp(self):
        """Build a dataset with a symmetric, a right-skewed and a text column."""
        # Set consistent size for all columns
        n_samples = 1000

        # The Pass/Fail assertions below depend on the sampled skewness, so the
        # samples are drawn from a locally seeded generator: every run sees the
        # same data and NumPy's global random state is left untouched.
        rng = np.random.default_rng(42)

        # Create a dataset with known skewness
        # Normal distribution (low skewness)
        normal_data = rng.normal(0, 1, n_samples)

        # Right-skewed distribution (high positive skewness)
        skewed_data = rng.exponential(2, n_samples)

        # Non-numeric column: cycle through the labels so the column always has
        # exactly n_samples entries, whether or not n_samples divides by 3.
        labels = ("A", "B", "C")
        categorical = [labels[i % len(labels)] for i in range(n_samples)]

        df = pd.DataFrame(
            {"normal": normal_data, "skewed": skewed_data, "categorical": categorical}
        )

        self.vm_dataset = vm.init_dataset(
            input_id="test_dataset",
            dataset=df,
            feature_columns=["normal", "skewed", "categorical"],
            __log=False,
        )

    def test_skewness_threshold(self):
        """With the default threshold only the right-skewed column fails."""
        # Test with default threshold (1)
        results, passed = Skewness(self.vm_dataset)

        # Check return types
        self.assertIsInstance(results, dict)
        self.assertIn(passed, [True, False])

        # Check results structure
        results_table = results["Skewness Results for Dataset"]
        self.assertIsInstance(results_table, list)

        # Verify only numeric columns are included
        column_names = {row["Column"] for row in results_table}
        self.assertEqual(column_names, {"normal", "skewed"})

        # Normal distribution should pass, skewed should fail
        for row in results_table:
            if row["Column"] == "normal":
                self.assertEqual(row["Pass/Fail"], "Pass")
            if row["Column"] == "skewed":
                self.assertEqual(row["Pass/Fail"], "Fail")

    def test_custom_threshold(self):
        """A very high ``max_threshold`` lets every reported column pass."""
        # Test with very high threshold (all should pass)
        results, passed = Skewness(self.vm_dataset, max_threshold=10)
        results_table = results["Skewness Results for Dataset"]

        # All columns should pass with high threshold
        self.assertTrue(passed)
        self.assertTrue(all(row["Pass/Fail"] == "Pass" for row in results_table))
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import unittest
2+
import pandas as pd
3+
import matplotlib.pyplot as plt
4+
5+
import validmind as vm
6+
7+
from validmind.errors import SkipTestError
8+
from validmind.tests.data_validation.SpreadPlot import SpreadPlot
9+
10+
11+
class TestSpreadPlot(unittest.TestCase):
    """Unit tests for the ``SpreadPlot`` data validation test."""

    def setUp(self):
        """Build one dataset with a daily datetime index and one without."""
        # Registered up front so matplotlib figures are closed after every
        # test, including when an assertion fails before the end of the test.
        self.addCleanup(plt.close, "all")

        # Create a sample time series dataset
        dates = pd.date_range(start="2023-01-01", periods=100, freq="D")
        df = pd.DataFrame(
            {"A": range(100), "B": [i * 2 for i in range(100)]}, index=dates
        )

        self.vm_dataset = vm.init_dataset(
            input_id="test_dataset", dataset=df, feature_columns=["A", "B"], __log=False
        )

        # Create a dataset without datetime index
        df_no_datetime = pd.DataFrame(
            {"A": range(100), "B": [i * 2 for i in range(100)]}
        )

        self.vm_dataset_no_datetime = vm.init_dataset(
            input_id="test_dataset_no_datetime",
            dataset=df_no_datetime,
            feature_columns=["A", "B"],
            __log=False,
        )

    def test_spread_plot(self):
        """Two features yield a single matplotlib figure for the one pair."""
        figures = SpreadPlot(self.vm_dataset)

        # Check that we get the correct number of figures (one per feature pair)
        self.assertEqual(len(figures), 1)  # Only one pair (A-B) for two features

        # Check that outputs are matplotlib figures
        for fig in figures:
            self.assertIsInstance(fig, plt.Figure)

    def test_no_datetime_index(self):
        """A dataset without a datetime index raises ``SkipTestError``."""
        # Should raise an error for non-datetime index
        with self.assertRaises(SkipTestError):
            SpreadPlot(self.vm_dataset_no_datetime)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import unittest
2+
import pandas as pd
3+
import validmind as vm
4+
import plotly.graph_objs as go
5+
from validmind.errors import SkipTestError
6+
from validmind.tests.data_validation.TabularCategoricalBarPlots import (
7+
TabularCategoricalBarPlots,
8+
)
9+
10+
11+
class TestTabularCategoricalBarPlots(unittest.TestCase):
    """Exercise ``TabularCategoricalBarPlots`` with and without categorical data."""

    def setUp(self):
        """Register a mixed-type dataset and an all-numeric dataset."""
        # Two string-valued columns and one numeric column, 100 rows each.
        mixed_columns = {
            "cat1": ["A", "B", "C", "A", "B"] * 20,
            "cat2": ["X", "Y", "X", "Y", "X"] * 20,
            "numeric": range(100),
        }
        mixed_frame = pd.DataFrame(mixed_columns)
        self.vm_dataset = vm.init_dataset(
            input_id="test_dataset",
            dataset=mixed_frame,
            feature_columns=["cat1", "cat2", "numeric"],
            __log=False,
        )

        # Same number of rows, but nothing categorical in it.
        numeric_columns = {"numeric1": range(100), "numeric2": range(100, 200)}
        numeric_frame = pd.DataFrame(numeric_columns)
        self.vm_dataset_no_cat = vm.init_dataset(
            input_id="test_dataset_no_cat",
            dataset=numeric_frame,
            feature_columns=["numeric1", "numeric2"],
            __log=False,
        )

    def test_categorical_bar_plots(self):
        """The result is a tuple of two plotly figures (for cat1 and cat2)."""
        plots = TabularCategoricalBarPlots(self.vm_dataset)

        # One figure is expected per categorical column.
        self.assertIsInstance(plots, tuple)
        self.assertEqual(len(plots), 2)

        # Every element must be a plotly figure.
        for plot in plots:
            self.assertIsInstance(plot, go.Figure)

    def test_no_categorical_columns(self):
        """An all-numeric dataset makes the test raise ``SkipTestError``."""
        self.assertRaises(
            SkipTestError, TabularCategoricalBarPlots, self.vm_dataset_no_cat
        )
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import unittest
2+
import pandas as pd
3+
import validmind as vm
4+
import plotly.graph_objs as go
5+
from validmind.errors import SkipTestError
6+
from validmind.tests.data_validation.TabularDateTimeHistograms import (
7+
TabularDateTimeHistograms,
8+
)
9+
10+
11+
class TestTabularDateTimeHistograms(unittest.TestCase):
    """Exercise ``TabularDateTimeHistograms`` with and without a datetime index."""

    def setUp(self):
        """Register a datetime-indexed dataset and a default-indexed one."""
        feature_names = ["A", "B"]

        # 100 daily timestamps used as the index of the first frame.
        daily_index = pd.date_range(start="2023-01-01", periods=100, freq="D")
        indexed_frame = pd.DataFrame(
            {"A": range(100), "B": range(100, 200)},
            index=daily_index,
        )
        self.vm_dataset = vm.init_dataset(
            input_id="test_dataset",
            dataset=indexed_frame,
            feature_columns=feature_names,
            __log=False,
        )

        # Identical columns, but pandas' default index instead of dates.
        plain_frame = pd.DataFrame({"A": range(100), "B": range(100, 200)})
        self.vm_dataset_no_datetime = vm.init_dataset(
            input_id="test_dataset_no_datetime",
            dataset=plain_frame,
            feature_columns=["A", "B"],
            __log=False,
        )

    def test_datetime_histograms(self):
        """A datetime-indexed dataset produces a single plotly figure."""
        histogram = TabularDateTimeHistograms(self.vm_dataset)
        self.assertIsInstance(histogram, go.Figure)

    def test_no_datetime_index(self):
        """Without a datetime index the test raises ``SkipTestError``."""
        self.assertRaises(
            SkipTestError, TabularDateTimeHistograms, self.vm_dataset_no_datetime
        )

0 commit comments

Comments
 (0)