|
2 | 2 |
|
3 | 3 | import numpy as np |
4 | 4 | import pandas as pd |
| 5 | +from trino.types import NamedRowTuple |
5 | 6 |
|
6 | 7 | from deepnote_toolkit.ocelots.constants import DEEPNOTE_INDEX_COLUMN |
7 | 8 | from deepnote_toolkit.ocelots.pandas.analyze import analyze_columns |
@@ -575,5 +576,89 @@ def test_mixed_column_types(self): |
575 | 576 | self.assertIsNotNone(col.stats) |
576 | 577 |
|
577 | 578 |
|
| 579 | +class TestAnalyzeColumnsWithTrinoTypes(unittest.TestCase): |
| 580 | + def test_analyze_columns_with_named_row_tuple(self): |
| 581 | + row1 = NamedRowTuple( |
| 582 | + values=[1, "Alice"], names=["id", "name"], types=["integer", "varchar"] |
| 583 | + ) |
| 584 | + row2 = NamedRowTuple( |
| 585 | + values=[2, "Bob"], names=["id", "name"], types=["integer", "varchar"] |
| 586 | + ) |
| 587 | + row3 = NamedRowTuple( |
| 588 | + values=[1, "Alice"], names=["id", "name"], types=["integer", "varchar"] |
| 589 | + ) |
| 590 | + |
| 591 | + np_array = np.empty(3, dtype=object) |
| 592 | + np_array[0] = row1 |
| 593 | + np_array[1] = row2 |
| 594 | + np_array[2] = row3 |
| 595 | + |
| 596 | + df = pd.DataFrame({"col1": np_array}) |
| 597 | + result = analyze_columns(df) |
| 598 | + |
| 599 | + self.assertEqual(len(result), 1) |
| 600 | + self.assertEqual(result[0].name, "col1") |
| 601 | + self.assertEqual(result[0].dtype, "object") |
| 602 | + self.assertIsNotNone(result[0].stats) |
| 603 | + self.assertIsNotNone(result[0].stats.categories) |
| 604 | + self.assertIsInstance(result[0].stats.categories, list) |
| 605 | + self.assertGreater(len(result[0].stats.categories), 0) |
| 606 | + for category in result[0].stats.categories: |
| 607 | + self.assertIn("name", category) |
| 608 | + self.assertIn("count", category) |
| 609 | + |
| 610 | + def test_analyze_columns_with_named_row_tuple_and_missing_values(self): |
| 611 | + row1 = NamedRowTuple( |
| 612 | + values=[1, "Alice"], names=["id", "name"], types=["integer", "varchar"] |
| 613 | + ) |
| 614 | + row2 = NamedRowTuple( |
| 615 | + values=[2, "Bob"], names=["id", "name"], types=["integer", "varchar"] |
| 616 | + ) |
| 617 | + |
| 618 | + np_array = np.empty(4, dtype=object) |
| 619 | + np_array[0] = row1 |
| 620 | + np_array[1] = row2 |
| 621 | + np_array[2] = None |
| 622 | + np_array[3] = row1 |
| 623 | + |
| 624 | + df = pd.DataFrame({"col1": np_array}) |
| 625 | + result = analyze_columns(df) |
| 626 | + |
| 627 | + self.assertEqual(len(result), 1) |
| 628 | + self.assertIsNotNone(result[0].stats) |
| 629 | + self.assertIsNotNone(result[0].stats.categories) |
| 630 | + |
| 631 | + category_names = [cat["name"] for cat in result[0].stats.categories] |
| 632 | + self.assertIn("Missing", category_names) |
| 633 | + |
| 634 | + missing_cat = next( |
| 635 | + cat for cat in result[0].stats.categories if cat["name"] == "Missing" |
| 636 | + ) |
| 637 | + self.assertEqual(missing_cat["count"], 1) |
| 638 | + |
| 639 | + def test_analyze_columns_with_many_named_row_tuples(self): |
| 640 | + np_array = np.empty(20, dtype=object) |
| 641 | + for i in range(10): |
| 642 | + row = NamedRowTuple( |
| 643 | + values=[i, f"User{i}"], |
| 644 | + names=["id", "name"], |
| 645 | + types=["integer", "varchar"], |
| 646 | + ) |
| 647 | + np_array[i * 2] = row |
| 648 | + np_array[i * 2 + 1] = row |
| 649 | + |
| 650 | + df = pd.DataFrame({"col1": np_array}) |
| 651 | + result = analyze_columns(df) |
| 652 | + |
| 653 | + self.assertEqual(len(result), 1) |
| 654 | + self.assertIsNotNone(result[0].stats) |
| 655 | + self.assertIsNotNone(result[0].stats.categories) |
| 656 | + self.assertGreaterEqual(len(result[0].stats.categories), 1) |
| 657 | + self.assertLessEqual(len(result[0].stats.categories), 3) |
| 658 | + |
| 659 | + has_others = any("others" in cat["name"] for cat in result[0].stats.categories) |
| 660 | + self.assertTrue(has_others) |
| 661 | + |
| 662 | + |
578 | 663 | if __name__ == "__main__": |
579 | 664 | unittest.main() |
0 commit comments