Apply changes from juan5508/sc-3259/Create-Metric-V2-class #1
Merged
Changes from all commits (33 commits)

6dd9218  Apply changes from juan5508/sc-3259/Create-Metric-V2-class (juanmleng)
17eb5e9  Merge branch 'main' (juanmleng)
cbe7cec  Add unit metrics notebook (juanmleng)
65d0302  Update unit metrics to use new inputs (juanmleng)
2716d02  Add unit_metrics namespace (juanmleng)
fbda9d3  Improve dataset serialization (juanmleng)
8a69786  Update notebooks/how_to/run_unit_metrics.ipynb (juanmleng)
35b8f41  Update notebooks/how_to/run_unit_metrics.ipynb (juanmleng)
6b6fbe9  Update notebooks/how_to/run_unit_metrics.ipynb (juanmleng)
e6963f9  Move notebook to code_sharing folder (juanmleng)
f249b49  Delete obsolete define_metrics_v2.ipynb (juanmleng)
9240ddf  model is a required input as well (juanmleng)
c55134b  Remove unused imports (juanmleng)
77eafb4  remove get_metric_value (juanmleng)
ffd883d  Update summary method (juanmleng)
a9854bd  rename metric_v2 to unit_metric (juanmleng)
cd125c0  undo non-desired changes to dataset (juanmleng)
3f971c3  Remove unused properties (juanmleng)
e640336  Remove result metadata (juanmleng)
90a8740  Update notebook (juanmleng)
4df599e  Remove type and scope (juanmleng)
7014fcc  Fix lint (juanmleng)
cbee808  Remove get_y_pred function (juanmleng)
3f1d5ec  Update hash method (juanmleng)
f41e750  Refactor sklearn unit metrics (juanmleng)
e32df39  Merge branch 'main' (juanmleng)
99cc64c  Fix dependencies (juanmleng)
b645c35  Clean notebook (juanmleng)
6843608  Add copyright (juanmleng)
28d2af4  Remove audit code in notebook (juanmleng)
35f7776  Simplify notebook (juanmleng)
f00b2b1  Clear notebook output (juanmleng)
239f8f8  Remove unused keys (juanmleng)
@@ -0,0 +1,244 @@
# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
# See the LICENSE file in the root of this repository for details.
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

import numpy as np
import json
import hashlib
import importlib

from validmind.vm_models import TestInput

from ..utils import get_model_info


unit_metric_results_cache = {}


def _serialize_params(params):
    """
    Serialize the parameters to a unique hash, handling None values.

    This function serializes the parameters dictionary to a JSON string,
    then creates a SHA-256 hash of the string to ensure a unique identifier
    for the parameters. If params is None, a default hash is returned.

    Args:
        params (dict or None): The parameters to be serialized.

    Returns:
        str: A SHA-256 hash of the JSON string representation of the params,
            or a default hash if params is None.
    """
    if params is None:
        # Handle None by returning a hash of an empty dictionary or a predefined value
        params_json = json.dumps({})
    else:
        params_json = json.dumps(params, sort_keys=True)

    hash_object = hashlib.sha256(params_json.encode())
    return hash_object.hexdigest()


def _serialize_model(model):
    """
    Generate a SHA-256 hash for a scikit-learn model based on its type and parameters.

    Args:
        model (VMModel): The model to be serialized.

    Returns:
        str: A SHA-256 hash of the model's description.
    """
    model_info = get_model_info(model)

    model_json = json.dumps(model_info, sort_keys=True)

    # Create a SHA-256 hash of the JSON string
    hash_object = hashlib.sha256(model_json.encode())
    return hash_object.hexdigest()


def _serialize_dataset(dataset, model_id):
    """
    Serialize the description of the dataset input to a unique hash.

    This function generates a hash based on the dataset's structure, including
    the target and feature columns, the prediction column associated with a specific model ID,
    and directly incorporates the model ID and prediction column name to ensure uniqueness.

    Args:
        dataset: The dataset object, which should have properties like _df (pandas DataFrame),
            target_column (string), feature_columns (list of strings), and _extra_columns (dict).
        model_id (str): The ID of the model associated with the prediction column.

    Returns:
        str: A SHA-256 hash representing the dataset.

    Note:
        Including the model ID and prediction column name in the hash calculation ensures uniqueness,
        especially in cases where the predictions are sparse or the dataset has not significantly changed.
        This approach guarantees that the hash will distinguish between model-generated predictions
        and pre-computed prediction columns, addressing potential hash collisions.
    """
    # Access the prediction column for the given model ID from the dataset's extra columns
    prediction_column_name = dataset._extra_columns["prediction_columns"][model_id]

    # Include model ID and prediction column name directly in the hash calculation
    model_and_prediction_info = f"{model_id}_{prediction_column_name}".encode()

    # Start with target and feature columns, and include the prediction column
    columns = (
        [dataset._target_column] + dataset._feature_columns + [prediction_column_name]
    )

    # Use _fast_hash function and include model_and_prediction_info in the hash calculation
    hash_digest = _fast_hash(
        dataset._df[columns], model_and_prediction_info=model_and_prediction_info
    )

    return hash_digest


def _fast_hash(df, sample_size=1000, model_and_prediction_info=None):
    """
    Generates a hash for a DataFrame by sampling and combining its size, content,
    and optionally model and prediction information.

    Args:
        df (pd.DataFrame): The DataFrame to hash.
        sample_size (int): The maximum number of rows to include in the sample.
        model_and_prediction_info (bytes, optional): Additional information to include in the hash.

    Returns:
        str: A SHA-256 hash of the DataFrame's sample and additional information.
    """
    # Convert the number of rows to bytes and include it in the hash calculation
    rows_bytes = str(len(df)).encode()

    # Sample rows if DataFrame is larger than sample_size, ensuring reproducibility
    if len(df) > sample_size:
        df_sample = df.sample(n=sample_size, random_state=42)
    else:
        df_sample = df

    # Convert the sampled DataFrame to a byte array. np.asarray ensures compatibility with various DataFrame contents.
    byte_array = np.asarray(df_sample).data.tobytes()

    # Initialize the hash object and update it with the row count, data bytes, and additional info
    hash_obj = hashlib.sha256(
        rows_bytes + byte_array + (model_and_prediction_info or b"")
    )

    return hash_obj.hexdigest()


def _get_metric_class(metric_id):
    """Get the metric class by metric_id.

    This function loads the metric class by metric_id.

    Args:
        metric_id (str): The full metric id (e.g. 'validmind.vm_models.test.v2.model_validation.sklearn.F1')

    Returns:
        Metric: The metric class
    """
    metric_module = importlib.import_module(f"{metric_id}")

    class_name = metric_id.split(".")[-1]

    # Access the metric class within the imported module
    metric_class = getattr(metric_module, class_name)

    return metric_class


def get_input_type(input_obj):
    """
    Determines whether the input object is a 'dataset' or 'model' based on its class module path.

    Args:
        input_obj: The object to type check.

    Returns:
        str: 'dataset' or 'model' depending on the object's module, or raises ValueError.
    """
    # Obtain the class object of input_obj (for clarity and debugging)
    class_obj = input_obj.__class__

    # Obtain the module name as a string from the class object
    class_module = class_obj.__module__

    if "validmind.vm_models.dataset" in class_module:
        return "dataset"
    elif "validmind.models" in class_module:
        return "model"
    else:
        raise ValueError("Input must be of type validmind Dataset or Model")


def get_metric_cache_key(metric_id, params, inputs):
    cache_elements = [metric_id]

    # Serialize params if not None
    serialized_params = _serialize_params(params) if params else "None"
    cache_elements.append(serialized_params)

    # Check if 'inputs' is a dictionary
    if not isinstance(inputs, dict):
        raise TypeError("Expected 'inputs' to be a dictionary.")

    # Check for 'model' and 'dataset' keys in 'inputs'
    if "model" not in inputs or "dataset" not in inputs:
        raise ValueError("Missing 'model' or 'dataset' in 'inputs'.")

    dataset = inputs["dataset"]
    model = inputs["model"]
    model_id = model.input_id

    cache_elements.append(_serialize_dataset(dataset, model_id))

    cache_elements.append(_serialize_model(model))

    # Combine elements to form the cache key
    combined_elements = "_".join(cache_elements)
    key = hashlib.sha256(combined_elements.encode()).hexdigest()
    return key


def run_metric(metric_id=None, inputs=None, params=None):
    """Run a single metric.

    This function provides a high-level interface for running a single metric. A metric
    is a single test that calculates a value based on the input data.

    Args:
        metric_id (str): The full metric id (e.g. 'validmind.vm_models.test.v2.model_validation.sklearn.F1')
        inputs (dict): A dictionary of the metric inputs, with 'model' and 'dataset' keys
        params (dict): A dictionary of the metric parameters

    Returns:
        MetricResult: The metric result object
    """
    cache_key = get_metric_cache_key(metric_id, params, inputs)

    # Check if the metric value already exists in the global cache
    if cache_key in unit_metric_results_cache:
        return unit_metric_results_cache[cache_key]

    # Load the metric class by metric_id
    metric_class = _get_metric_class(metric_id)

    # Initialize the metric
    metric = metric_class(test_id=metric_id, inputs=TestInput(inputs), params=params)

    # Run the metric
    result = metric.run()

    cache_key = get_metric_cache_key(metric_id, params, inputs)

    unit_metric_results_cache[cache_key] = result

    return result
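For context, a minimal usage sketch of run_metric as exposed by this module. The import path and the pre-initialized vm_model / vm_dataset objects are assumptions (they are not part of this diff); the metric_id format and the inputs / params shapes follow the docstrings above.

# Minimal sketch (not part of this PR). Assumes `vm_dataset` and `vm_model` are
# ValidMind dataset/model objects initialized elsewhere, with predictions for
# `vm_model` already assigned to the dataset.
from validmind.unit_metrics import run_metric  # import path assumed, not shown in this diff

metric_id = "validmind.vm_models.test.v2.model_validation.sklearn.F1"

result = run_metric(
    metric_id=metric_id,
    inputs={
        "model": vm_model,      # required: hashed via get_model_info and used to pick the prediction column
        "dataset": vm_dataset,  # required: hashed together with the model id and prediction column
    },
    params={"average": "macro"},  # forwarded as-is to sklearn.metrics.f1_score
)

# A second call with the same metric_id, inputs, and params produces the same
# cache key (get_metric_cache_key) and returns the cached result instead of
# recomputing the metric.
same_result = run_metric(
    metric_id=metric_id,
    inputs={"model": vm_model, "dataset": vm_dataset},
    params={"average": "macro"},
)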
@@ -0,0 +1,21 @@
# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
# See the LICENSE file in the root of this repository for details.
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

from dataclasses import dataclass

from sklearn.metrics import accuracy_score

from validmind.vm_models import UnitMetric


@dataclass
class Accuracy(UnitMetric):

    def run(self):
        y_true = self.inputs.dataset.y
        y_pred = self.inputs.dataset.y_pred(model_id=self.inputs.model.input_id)

        value = accuracy_score(y_true, y_pred, **self.params)

        return self.cache_results(metric_value=value)
@@ -0,0 +1,23 @@
# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
# See the LICENSE file in the root of this repository for details.
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

from dataclasses import dataclass

from sklearn.metrics import f1_score

from validmind.vm_models import UnitMetric


@dataclass
class F1(UnitMetric):

    def run(self):
        y_true = self.inputs.dataset.y
        y_pred = self.inputs.dataset.y_pred(model_id=self.inputs.model.input_id)

        value = f1_score(y_true, y_pred, **self.params)

        return self.cache_results(
            metric_value=value,
        )
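The two sklearn unit metrics above share one shape: subclass UnitMetric, read y and the model-specific y_pred from the inputs, pass self.params to the corresponding sklearn scorer, and hand the scalar to cache_results. As a sketch, a hypothetical Precision metric (not in this PR) would follow the same pattern:

# Hypothetical Precision unit metric, mirroring the Accuracy and F1 classes above.
from dataclasses import dataclass

from sklearn.metrics import precision_score

from validmind.vm_models import UnitMetric


@dataclass
class Precision(UnitMetric):

    def run(self):
        # Ground-truth labels and the predictions registered for this model
        y_true = self.inputs.dataset.y
        y_pred = self.inputs.dataset.y_pred(model_id=self.inputs.model.input_id)

        # Any params passed to run_metric are forwarded to the sklearn scorer
        value = precision_score(y_true, y_pred, **self.params)

        return self.cache_results(metric_value=value)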