Skip to content

Commit

Permalink
review changes (Azure#34480)
Browse files Browse the repository at this point in the history
  • Loading branch information
nemanjarajic authored Feb 28, 2024
1 parent 404b60e commit b2c66cb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(
self.instance_type = instance_type

def _to_rest_object(self) -> MonitorServerlessSparkCompute:
self.validate()
self._validate()
return MonitorServerlessSparkCompute(
runtime_version=self.runtime_version,
instance_type=self.instance_type,
Expand All @@ -42,7 +42,7 @@ def _from_rest_object(cls, obj: MonitorServerlessSparkCompute) -> "ServerlessSpa
instance_type=obj.instance_type,
)

def validate(self) -> None:
def _validate(self) -> None:
if self.runtime_version != "3.3":
msg = "Compute runtime version must be 3.3"
err = ValidationException(
Expand Down
19 changes: 4 additions & 15 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_monitoring/thresholds.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,6 @@ def _get_default_thresholds(cls) -> "NumericalDriftMetrics":
def defaults(cls) -> "NumericalDriftMetrics":
return cls._get_default_thresholds()

def get_name_and_threshold(self) -> Tuple:
return self._find_name_and_threshold()


class CategoricalDriftMetrics(RestTranslatableMixin):
"""Categorical Drift Metrics
Expand Down Expand Up @@ -204,15 +201,15 @@ def __init__(
def _to_rest_object(self) -> DataDriftMetricThresholdBase:
thresholds = []
if self.numerical:
num_metric_name, num_threshold = self.numerical.get_name_and_threshold()
num_metric_name, num_threshold = self.numerical._find_name_and_threshold()
thresholds.append(
NumericalDataDriftMetricThreshold(
metric=snake_to_camel(num_metric_name),
threshold=num_threshold,
)
)
if self.categorical:
cat_metric_name, cat_threshold = self.categorical.get_name_and_threshold()
cat_metric_name, cat_threshold = self.categorical._find_name_and_threshold()
thresholds.append(
CategoricalDataDriftMetricThreshold(
metric=snake_to_camel(cat_metric_name),
Expand Down Expand Up @@ -279,15 +276,15 @@ def __init__(
def _to_rest_object(self) -> PredictionDriftMetricThresholdBase:
thresholds = []
if self.numerical:
num_metric_name, num_threshold = self.numerical.get_name_and_threshold()
num_metric_name, num_threshold = self.numerical._find_name_and_threshold()
thresholds.append(
NumericalPredictionDriftMetricThreshold(
metric=snake_to_camel(num_metric_name),
threshold=num_threshold,
)
)
if self.categorical:
cat_metric_name, cat_threshold = self.categorical.get_name_and_threshold()
cat_metric_name, cat_threshold = self.categorical._find_name_and_threshold()
thresholds.append(
CategoricalPredictionDriftMetricThreshold(
metric=snake_to_camel(cat_metric_name),
Expand Down Expand Up @@ -404,10 +401,6 @@ def _get_default_thresholds(cls) -> "DataQualityMetricsNumerical":
out_of_bounds_rate=0.0,
)

@classmethod
def defaults(cls) -> "DataQualityMetricsNumerical":
return cls._get_default_thresholds()


class DataQualityMetricsCategorical(RestTranslatableMixin):
"""Data Quality Categorical Metrics
Expand Down Expand Up @@ -480,10 +473,6 @@ def _get_default_thresholds(cls) -> "DataQualityMetricsCategorical":
out_of_bounds_rate=0.0,
)

@classmethod
def defaults(cls) -> "DataQualityMetricsCategorical":
return cls._get_default_thresholds()


class DataQualityMetricThreshold(MetricThreshold):
"""Data quality metric threshold
Expand Down

0 comments on commit b2c66cb

Please sign in to comment.