Gsq signal #31998

Merged: 7 commits, Sep 18, 2023
Changes from 3 commits
@@ -15,6 +15,7 @@
PredictionDriftSignalSchema,
FeatureAttributionDriftSignalSchema,
CustomMonitoringSignalSchema,
GenerationSafetyQualitySchema,
)
from azure.ai.ml._schema.monitoring.alert_notification import AlertNotificationSchema
from azure.ai.ml._schema.core.fields import NestedField, UnionField, StringTransformedEnum
@@ -33,6 +34,7 @@ class MonitorDefinitionSchema(metaclass=PatchedSchemaMeta):
NestedField(PredictionDriftSignalSchema),
NestedField(FeatureAttributionDriftSignalSchema),
NestedField(CustomMonitoringSignalSchema),
NestedField(GenerationSafetyQualitySchema),
]
),
)
38 changes: 38 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/monitoring/signals.py
@@ -24,6 +24,7 @@
FeatureAttributionDriftMetricThresholdSchema,
ModelPerformanceMetricThresholdSchema,
CustomMonitoringMetricThresholdSchema,
GenerationSafetyQualityMetricThresholdSchema,
)


@@ -275,3 +276,40 @@ def make(self, data, **kwargs):

data.pop("type", None)
return CustomMonitoringSignal(**data)


class LlmRequestResponseDataSchema(metaclass=PatchedSchemaMeta):
input_data = UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)])
data_column_names = fields.Dict()
data_window_size = fields.Str()

@post_load
def make(self, data, **kwargs):
from azure.ai.ml.entities._monitoring.signals import LlmRequestResponseData

return LlmRequestResponseData(**data)


class GenerationSafetyQualitySchema(metaclass=PatchedSchemaMeta):
type = StringTransformedEnum(allowed_values=MonitorSignalType.GENERATION_SAFETY_QUALITY, required=True)
production_data = fields.List(NestedField(LlmRequestResponseDataSchema))
workspace_connection_id = fields.Str()
metric_thresholds = NestedField(GenerationSafetyQualityMetricThresholdSchema)
alert_enabled = fields.Bool()
properties = fields.Dict()
sampling_rate = fields.Float()

@pre_dump
def predump(self, data, **kwargs):
from azure.ai.ml.entities._monitoring.signals import GenerationSafetyQualitySignal

if not isinstance(data, GenerationSafetyQualitySignal):
raise ValidationError("Cannot dump non-GenerationSafetyQuality object into GenerationSafetyQuality")
return data

@post_load
def make(self, data, **kwargs):
from azure.ai.ml.entities._monitoring.signals import GenerationSafetyQualitySignal

data.pop("type", None)
return GenerationSafetyQualitySignal(**data)
29 changes: 29 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/monitoring/thresholds.py
@@ -140,3 +140,32 @@ def make(self, data, **kwargs):
from azure.ai.ml.entities._monitoring.thresholds import CustomMonitoringMetricThreshold

return CustomMonitoringMetricThreshold(**data)


class GenerationSafetyQualityMetricThresholdSchema(metaclass=PatchedSchemaMeta):
    groundedness = fields.Dict(
        keys=StringTransformedEnum(
            allowed_values=["acceptable_groundedness_score_per_instance", "aggregated_groundedness_pass_rate"]
        ),
        values=fields.Number(),
    )
    relevance = fields.Dict(
        keys=StringTransformedEnum(
            allowed_values=["acceptable_relevance_score_per_instance", "aggregated_relevance_pass_rate"]
        ),
        values=fields.Number(),
    )
    coherence = fields.Dict(
        keys=StringTransformedEnum(
            allowed_values=["acceptable_coherence_score_per_instance", "aggregated_coherence_pass_rate"]
        ),
        values=fields.Number(),
    )
    fluency = fields.Dict(
        keys=StringTransformedEnum(
            allowed_values=["acceptable_fluency_score_per_instance", "aggregated_fluency_pass_rate"]
        ),
        values=fields.Number(),
    )
    similarity = fields.Dict(
        keys=StringTransformedEnum(
            allowed_values=["acceptable_similarity_score_per_instance", "aggregated_similarity_pass_rate"]
        ),
        values=fields.Number(),
    )

@post_load
def make(self, data, **kwargs):
from azure.ai.ml.entities._monitoring.thresholds import GenerationSafetyQualityMonitoringMetricThreshold

return GenerationSafetyQualityMonitoringMetricThreshold(**data)
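For orientation (not part of the diff): a minimal sketch of the dict/YAML shape that GenerationSafetyQualitySchema and the nested LlmRequestResponseDataSchema / GenerationSafetyQualityMetricThresholdSchema declare fields for. The asset path, column names, and connection id are placeholders, and only two of the allowed metric names are shown.

```python
# Hypothetical config sketch; field names mirror the schema attributes above,
# values (path, column names, connection id) are placeholders.
gsq_signal_config = {
    "type": "generation_safety_quality",
    "production_data": [
        {
            "input_data": {"path": "azureml:llm_request_logs:1", "type": "uri_folder"},
            "data_column_names": {"prompt_column": "question", "completion_column": "answer"},
            "data_window_size": "P7D",
        }
    ],
    "workspace_connection_id": "<workspace-connection-resource-id>",
    "metric_thresholds": {
        "groundedness": {"aggregated_groundedness_pass_rate": 0.9},
        "relevance": {"acceptable_relevance_score_per_instance": 4.0},
    },
    "alert_enabled": True,
    "sampling_rate": 0.1,
}
```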
1 change: 1 addition & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/constants/_monitoring.py
@@ -41,6 +41,7 @@ class MonitorSignalType(str, Enum, metaclass=CaseInsensitiveEnumMeta):
MODEL_PERFORMANCE = "model_performance"
FEATURE_ATTRIBUTION_DRIFT = "feature_attribution_drift"
CUSTOM = "custom"
GENERATION_SAFETY_QUALITY = "generation_safety_quality"


@experimental
@@ -26,6 +26,7 @@
FeatureAttributionDriftSignal,
MonitoringSignal,
PredictionDriftSignal,
GenerationSafetyQualitySignal,
)
from azure.ai.ml.entities._monitoring.target import MonitoringTarget
from azure.ai.ml.entities._monitoring.compute import ServerlessSparkCompute
@@ -73,6 +74,7 @@ def __init__(
PredictionDriftSignal,
FeatureAttributionDriftSignal,
CustomMonitoringSignal,
GenerationSafetyQualitySignal,
],
] = None,
alert_notification: Optional[Union[Literal[AZMONITORING], AlertNotification]] = None,
84 changes: 84 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_monitoring/signals.py
@@ -18,6 +18,9 @@
from azure.ai.ml._restclient.v2023_06_01_preview.models import (
DataQualityMonitoringSignal as RestMonitoringDataQualitySignal,
)
from azure.ai.ml._restclient.v2023_06_01_preview.models import (
GenerationSafetyQualityMonitoringSignal as RestGenerationSafetyQualityMonitoringSignal,
)
from azure.ai.ml._restclient.v2023_06_01_preview.models import (
FeatureAttributionDriftMonitoringSignal as RestFeatureAttributionDriftMonitoringSignal,
)
@@ -59,6 +62,7 @@
MetricThreshold,
ModelPerformanceMetricThreshold,
PredictionDriftMetricThreshold,
GenerationSafetyQualityMonitoringMetricThreshold,
)


@@ -339,6 +343,8 @@ def _from_rest_object( # pylint: disable=too-many-return-statements
return FeatureAttributionDriftSignal._from_rest_object(obj)
if obj.signal_type == MonitoringSignalType.CUSTOM:
return CustomMonitoringSignal._from_rest_object(obj)
if obj.signal_type == MonitoringSignalType.GENERATION_SAFETY_QUALITY:
return GenerationSafetyQualitySignal._from_rest_object(obj)

return None

@@ -933,6 +939,84 @@ def _from_rest_object(cls, obj: RestCustomMonitoringSignal) -> "CustomMonitoring
)


@experimental
class LlmRequestResponseData(RestTranslatableMixin):
def __init__(
self,
*,
input_data: Input,
        data_column_names: Optional[Dict[str, str]] = None,
        data_window_size: Optional[str] = None,
):
self.input_data = input_data
self.data_column_names = data_column_names
self.data_window_size = data_window_size

def _to_rest_object(self, **kwargs) -> RestMonitoringInputData:
if self.data_window_size is None:
self.data_window_size = kwargs.get("default_data_window_size")
return TrailingInputData(
target_columns=self.data_column_names,
job_type=self.input_data.type,
uri=self.input_data.path,
window_size=self.data_window_size,
window_offset=self.data_window_size,
)._to_rest_object()

@classmethod
def _from_rest_object(cls, obj: RestMonitoringInputData) -> "LlmRequestResponseData":
return cls(
input_data=Input(
path=obj.uri,
type=obj.job_input_type,
),
data_column_names=obj.columns,
data_window_size=isodate.duration_isoformat(obj.window_size),
)


@experimental
class GenerationSafetyQualitySignal(RestTranslatableMixin):
def __init__(
self,
*,
production_data: List[LlmRequestResponseData],
workspace_connection_id: str,
metric_thresholds: GenerationSafetyQualityMonitoringMetricThreshold,
alert_enabled: bool = True,
properties: Optional[Dict[str, str]] = None,
sampling_rate: Optional[float] = None,
):
self.type = MonitorSignalType.GENERATION_SAFETY_QUALITY
self.production_data = production_data
self.workspace_connection_id = workspace_connection_id
self.metric_thresholds = metric_thresholds
self.alert_enabled = alert_enabled
self.properties = properties
self.sampling_rate = sampling_rate

def _to_rest_object(self, **kwargs) -> RestGenerationSafetyQualityMonitoringSignal:
return RestGenerationSafetyQualityMonitoringSignal(
            production_data=[data._to_rest_object(**kwargs) for data in self.production_data],
workspace_connection_id=self.workspace_connection_id,
metric_thresholds=self.metric_thresholds._to_rest_object(),
mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED,
properties=self.properties,
sampling_rate=self.sampling_rate,
)

@classmethod
def _from_rest_object(cls, obj: RestGenerationSafetyQualityMonitoringSignal) -> "GenerationSafetyQualitySignal":
return cls(
production_data=[LlmRequestResponseData._from_rest_object(data) for data in obj.production_data],
workspace_connection_id=obj.workspace_connection_id,
metric_thresholds=GenerationSafetyQualityMonitoringMetricThreshold._from_rest_object(obj.metric_thresholds),
            alert_enabled=False
            if not obj.mode or (obj.mode and obj.mode == MonitoringNotificationMode.DISABLED)
            else True,
properties=obj.properties,
sampling_rate=obj.sampling_rate,
)
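A usage sketch (not part of the diff) wiring the new entities together through the constructors defined above. The data asset path, column names, and connection id are placeholders, and the private import paths simply follow the file locations in this PR.

```python
from azure.ai.ml import Input
from azure.ai.ml.entities._monitoring.signals import (
    GenerationSafetyQualitySignal,
    LlmRequestResponseData,
)
from azure.ai.ml.entities._monitoring.thresholds import (
    GenerationSafetyQualityMonitoringMetricThreshold,
)

# Placeholder data asset and column mapping for the production LLM logs.
production_data = LlmRequestResponseData(
    input_data=Input(type="uri_folder", path="azureml:llm_request_logs:1"),
    data_column_names={"prompt_column": "question", "completion_column": "answer"},
    data_window_size="P7D",
)

gsq_signal = GenerationSafetyQualitySignal(
    production_data=[production_data],
    workspace_connection_id="<workspace-connection-resource-id>",
    metric_thresholds=GenerationSafetyQualityMonitoringMetricThreshold(
        groundedness={"aggregated_groundedness_pass_rate": 0.9},
        relevance={"acceptable_relevance_score_per_instance": 4.0},
    ),
    alert_enabled=True,
    sampling_rate=0.1,
)
```

The signal can then be placed in a monitor definition's signal mapping alongside the existing drift and custom signals, which is what the MonitorDefinitionSchema and schedule changes above enable.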

def _from_rest_features(
obj: RestMonitoringFeatureFilterBase,
) -> Optional[Union[List[str], MonitorFeatureFilter, Literal[ALL_FEATURES]]]:
145 changes: 144 additions & 1 deletion sdk/ml/azure-ai-ml/azure/ai/ml/entities/_monitoring/thresholds.py
@@ -2,7 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from typing import Any, List, Optional
from typing import Any, Dict, List, Optional

from typing_extensions import Literal

@@ -25,6 +25,7 @@
RegressionModelPerformanceMetricThreshold,
RegressionModelPerformanceMetric,
CustomMetricThreshold,
GenerationSafetyQualityMetricThreshold,
)
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel
from azure.ai.ml._utils._experimental import experimental
@@ -663,3 +664,145 @@ def _to_rest_object(self) -> CustomMetricThreshold:
@classmethod
def _from_rest_object(cls, obj: CustomMetricThreshold) -> "CustomMonitoringMetricThreshold":
return cls(metric_name=obj.metric, threshold=obj.threshold.value if obj.threshold else None)


class GenerationSafetyQualityMonitoringMetricThreshold(RestTranslatableMixin):
"""Generation safety quality metric threshold

:param metric_name: The metric to calculate
:type metric_name: str
:param threshold: The threshold value. If None, a default value will be set
depending on the selected metric.
:type threshold: float
"""

    def __init__(
        self,
        *,
        groundedness: Optional[Dict[str, float]] = None,
        relevance: Optional[Dict[str, float]] = None,
        coherence: Optional[Dict[str, float]] = None,
        fluency: Optional[Dict[str, float]] = None,
        similarity: Optional[Dict[str, float]] = None,
    ):
        self.groundedness = groundedness
        self.relevance = relevance
        self.coherence = coherence
        self.fluency = fluency
        self.similarity = similarity

    def _to_rest_object(self) -> List[GenerationSafetyQualityMetricThreshold]:
metric_thresholds = []
if self.groundedness:
if "acceptable_groundedness_score_per_instance" in self.groundedness:
threshold = MonitoringThreshold(value=self.groundedness["acceptable_groundedness_score_per_instance"])
metric_thresholds.append(
GenerationSafetyQualityMetricThreshold(
metric="acceptable_groundedness_score_per_instance", threshold=threshold
)
)
if "aggregated_groundedness_pass_rate" in self.groundedness:
threshold = MonitoringThreshold(value=self.groundedness["aggregated_groundedness_pass_rate"])
metric_thresholds.append(
GenerationSafetyQualityMetricThreshold(
metric="aggregated_groundedness_pass_rate", threshold=threshold
)
)
        if self.relevance:
            if "acceptable_relevance_score_per_instance" in self.relevance:
                threshold = MonitoringThreshold(value=self.relevance["acceptable_relevance_score_per_instance"])
                metric_thresholds.append(
                    GenerationSafetyQualityMetricThreshold(
                        metric="acceptable_relevance_score_per_instance", threshold=threshold
                    )
                )
            if "aggregated_relevance_pass_rate" in self.relevance:
                threshold = MonitoringThreshold(value=self.relevance["aggregated_relevance_pass_rate"])
                metric_thresholds.append(
                    GenerationSafetyQualityMetricThreshold(
                        metric="aggregated_relevance_pass_rate", threshold=threshold
                    )
                )
if self.coherence:
if "acceptable_coherence_score_per_instance" in self.coherence:
threshold = MonitoringThreshold(value=self.coherence["acceptable_coherence_score_per_instance"])
metric_thresholds.append(
GenerationSafetyQualityMetricThreshold(
metric="acceptable_coherence_score_per_instance", threshold=threshold
)
)
if "aggregated_coherence_pass_rate" in self.coherence:
threshold = MonitoringThreshold(value=self.coherence["aggregated_coherence_pass_rate"])
metric_thresholds.append(
GenerationSafetyQualityMetricThreshold(
metric="aggregated_coherence_pass_rate", threshold=threshold
)
)
if self.fluency:
if "acceptable_fluency_score_per_instance" in self.fluency:
threshold = MonitoringThreshold(value=self.fluency["acceptable_fluency_score_per_instance"])
metric_thresholds.append(
GenerationSafetyQualityMetricThreshold(
metric="acceptable_fluency_score_per_instance", threshold=threshold
)
)
if "aggregated_fluency_pass_rate" in self.fluency:
threshold = MonitoringThreshold(value=self.fluency["aggregated_fluency_pass_rate"])
metric_thresholds.append(
GenerationSafetyQualityMetricThreshold(
metric="aggregated_fluency_pass_rate", threshold=threshold
)
)
if self.similarity:
if "acceptable_similarity_score_per_instance" in self.similarity:
threshold = MonitoringThreshold(value=self.similarity["acceptable_similarity_score_per_instance"])
metric_thresholds.append(
GenerationSafetyQualityMetricThreshold(
metric="acceptable_similarity_score_per_instance", threshold=threshold
)
)
if "aggregated_similarity_pass_rate" in self.similarity:
threshold = MonitoringThreshold(value=self.similarity["aggregated_similarity_pass_rate"])
metric_thresholds.append(
GenerationSafetyQualityMetricThreshold(
metric="aggregated_similarity_pass_rate", threshold=threshold
)
)
return metric_thresholds

@classmethod
    def _from_rest_object(
        cls, obj: List[GenerationSafetyQualityMetricThreshold]
    ) -> "GenerationSafetyQualityMonitoringMetricThreshold":
        groundedness = {}
        relevance = {}
        coherence = {}
        fluency = {}
        similarity = {}

        for threshold in obj:
            if threshold.metric == "acceptable_groundedness_score_per_instance":
                groundedness["acceptable_groundedness_score_per_instance"] = threshold.threshold.value
            if threshold.metric == "aggregated_groundedness_pass_rate":
                groundedness["aggregated_groundedness_pass_rate"] = threshold.threshold.value
            if threshold.metric == "acceptable_relevance_score_per_instance":
                relevance["acceptable_relevance_score_per_instance"] = threshold.threshold.value
            if threshold.metric == "aggregated_relevance_pass_rate":
                relevance["aggregated_relevance_pass_rate"] = threshold.threshold.value
            if threshold.metric == "acceptable_coherence_score_per_instance":
                coherence["acceptable_coherence_score_per_instance"] = threshold.threshold.value
            if threshold.metric == "aggregated_coherence_pass_rate":
                coherence["aggregated_coherence_pass_rate"] = threshold.threshold.value
            if threshold.metric == "acceptable_fluency_score_per_instance":
                fluency["acceptable_fluency_score_per_instance"] = threshold.threshold.value
            if threshold.metric == "aggregated_fluency_pass_rate":
                fluency["aggregated_fluency_pass_rate"] = threshold.threshold.value
            if threshold.metric == "acceptable_similarity_score_per_instance":
                similarity["acceptable_similarity_score_per_instance"] = threshold.threshold.value
            if threshold.metric == "aggregated_similarity_pass_rate":
                similarity["aggregated_similarity_pass_rate"] = threshold.threshold.value
        return cls(
            groundedness=groundedness if groundedness else None,
            relevance=relevance if relevance else None,
            coherence=coherence if coherence else None,
            fluency=fluency if fluency else None,
            similarity=similarity if similarity else None,
        )
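A small sketch (again not part of the diff) of the flattening that `_to_rest_object` performs: each per-category dict becomes one REST `GenerationSafetyQualityMetricThreshold` entry per metric name, and `_from_rest_object` groups them back. The threshold values are arbitrary placeholders.

```python
from azure.ai.ml.entities._monitoring.thresholds import (
    GenerationSafetyQualityMonitoringMetricThreshold,
)

thresholds = GenerationSafetyQualityMonitoringMetricThreshold(
    groundedness={
        "acceptable_groundedness_score_per_instance": 4.0,
        "aggregated_groundedness_pass_rate": 0.9,
    },
    fluency={"aggregated_fluency_pass_rate": 0.95},
)

# Three REST entries, one per metric name.
rest_thresholds = thresholds._to_rest_object()
for entry in rest_thresholds:
    print(entry.metric, entry.threshold.value)

# Round-trip back into the grouped, per-category form.
round_tripped = GenerationSafetyQualityMonitoringMetricThreshold._from_rest_object(rest_thresholds)
assert round_tripped.groundedness == thresholds.groundedness
```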