Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix bug with regression type checking
  • Loading branch information
stuartquin committed Sep 16, 2022
commit 8df703e6ef7dddab99b39e285e4144a876774c4c
16 changes: 9 additions & 7 deletions dataqa/infer_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Any, Dict, Optional, Union, List
from typing import Any, Dict, Optional, Tuple, Union, List

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -211,7 +211,7 @@ def is_subset(list1: List[Any], list2: List[Any]) -> bool:
def check_prediction_columns(
column_mapping: ColumnMapping,
column_to_categories: Dict[str, List[Union[str, np.number]]],
) -> dict:
) -> Dict:
schema_dict = dict(
(column, {"type": ColumnType.CATEGORICAL})
for column in column_mapping.categorical_columns
Expand Down Expand Up @@ -279,8 +279,10 @@ def check_prediction_columns(
)

if task == PredictionTask.REGRESSION:
if schema_dict[prediction_column] != ColumnType.NUMERICAL:
raise Exception(f"Regression tasks only valid with numerical columns.")
if schema_dict[prediction_column]["type"] != ColumnType.NUMERICAL:
raise Exception(
f"Regression tasks only valid with numerical columns {prediction_column}"
)

if task == PredictionTask.CLASSIFICATION:
if not schema_dict[prediction_column]["type"] in [
Expand All @@ -299,7 +301,7 @@ def format_validated_schema(
schema_dict: dict,
prediction_columns: List[PredictionColumn],
column_to_categories: Dict[str, List[Union[str, np.number]]],
) -> dict:
) -> List:
new_schema = []
prediction_columns_dict = {
column.prediction_column: column for column in prediction_columns
Expand All @@ -320,14 +322,14 @@ def format_validated_schema(
column_row["ground_truth"] = prediction_columns_dict[
column
].ground_truth_column
new_schema.append(column_row)
new_schema.append(column_row)

return new_schema


def validate_schema(
df: pd.DataFrame, column_mapping: ColumnMapping
) -> [ColumnMapping, pd.DataFrame]:
) -> Tuple[ColumnMapping, pd.DataFrame]:
categorical_columns = column_mapping.categorical_columns or []
numerical_columns = column_mapping.numerical_columns or []
text_columns = column_mapping.text_columns or []
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "dataqa"
version = "2.0.3"
version = "2.0.4"
description = "Python Client library for DataQA"
authors = ["Maria Mestre <maria@dataqa.ai>","Stuart Quin <stuart@dataqa.ai>"]
readme = "README.md"
Expand Down