Skip to content

Added support for must_compute flag #67

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions aimon/decorators/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ class Detect:
The name of the application to use when publish is True.
model_name : str, optional
The name of the model to use when publish is True.
must_compute : str, optional
Indicates the computation strategy. Must be either 'all_or_none' or 'ignore_failures'. Default is 'all_or_none'.

Example:
--------
Expand Down Expand Up @@ -133,7 +135,7 @@ class Detect:
"""
DEFAULT_CONFIG = {'hallucination': {'detector_name': 'default'}}

def __init__(self, values_returned, api_key=None, config=None, async_mode=False, publish=False, application_name=None, model_name=None):
def __init__(self, values_returned, api_key=None, config=None, async_mode=False, publish=False, application_name=None, model_name=None, must_compute='all_or_none'):
"""
:param values_returned: A list of values in the order returned by the decorated function
Acceptable values are 'generated_text', 'context', 'user_query', 'instructions'
Expand All @@ -144,6 +146,7 @@ def __init__(self, values_returned, api_key=None, config=None, async_mode=False,
:param publish: Boolean, if True, the payload will be published to AIMon and can be viewed on the AIMon UI. Default is False.
:param application_name: The name of the application to use when publish is True
:param model_name: The name of the model to use when publish is True
:param must_compute: String, indicates the computation strategy. Must be either 'all_or_none' or 'ignore_failures'. Default is 'all_or_none'.
"""
api_key = os.getenv('AIMON_API_KEY') if not api_key else api_key
if api_key is None:
Expand All @@ -163,8 +166,15 @@ def __init__(self, values_returned, api_key=None, config=None, async_mode=False,
if model_name is None:
raise ValueError("Model name must be provided if publish is True")

# Validate must_compute parameter
if not isinstance(must_compute, str):
raise ValueError("`must_compute` must be a string value")
if must_compute not in ['all_or_none', 'ignore_failures']:
raise ValueError("`must_compute` must be either 'all_or_none' or 'ignore_failures'")
self.must_compute = must_compute

self.application_name = application_name
self.model_name = model_name
self.model_name = model_name

def __call__(self, func):
@wraps(func)
Expand All @@ -181,6 +191,7 @@ def wrapper(*args, **kwargs):
aimon_payload['config'] = self.config
aimon_payload['publish'] = self.publish
aimon_payload['async_mode'] = self.async_mode
aimon_payload['must_compute'] = self.must_compute

# Include application_name and model_name if publishing
if self.publish:
Expand Down
5 changes: 3 additions & 2 deletions aimon/types/inference_detect_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class BodyConfigInstructionAdherence(TypedDict, total=False):
class BodyConfigToxicity(TypedDict, total=False):
detector_name: Literal["default"]


class BodyConfig(TypedDict, total=False):
completeness: BodyConfigCompleteness

Expand All @@ -61,7 +60,6 @@ class BodyConfig(TypedDict, total=False):

toxicity: BodyConfigToxicity


class Body(TypedDict, total=False):
context: Required[Union[List[str], str]]
"""Context as an array of strings or a single string"""
Expand All @@ -81,6 +79,9 @@ class Body(TypedDict, total=False):
model_name: str
"""The model name for publishing metrics for an application."""

must_compute: str
"""Indicates the computation strategy. Must be either 'all_or_none' or 'ignore_failures'."""

publish: bool
"""Indicates whether to publish metrics."""

Expand Down
155 changes: 155 additions & 0 deletions tests/test_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,3 +824,158 @@ def test_evaluate_with_new_model(self):
import os
if os.path.exists(dataset_path):
os.remove(dataset_path)

def test_must_compute_validation(self):
    """Test that the must_compute parameter is properly validated.

    Covers: both accepted strategy strings, an unrecognized string,
    a non-string value, and the documented default ('all_or_none').
    """
    print("\n=== Testing must_compute validation ===")

    # Config with two detectors so validation runs against a realistic setup.
    test_config = {
        "hallucination": {
            "detector_name": "default"
        },
        "completeness": {
            "detector_name": "default"
        }
    }
    print(f"Test Config: {test_config}")

    # Both accepted values must be stored verbatim on the instance.
    valid_values = ['all_or_none', 'ignore_failures']
    print(f"Testing valid must_compute values: {valid_values}")

    for value in valid_values:
        print(f"Testing valid must_compute value: {value}")
        detect = Detect(
            values_returned=["context", "generated_text"],
            api_key=self.api_key,
            config=test_config,
            must_compute=value
        )
        assert detect.must_compute == value
        print(f"✅ Successfully validated must_compute value: {value}")

    # An unrecognized string must raise ValueError naming the allowed values.
    invalid_string_value = "invalid_value"
    print(f"Testing invalid must_compute string value: {invalid_string_value}")
    try:
        Detect(
            values_returned=["context", "generated_text"],
            api_key=self.api_key,
            config=test_config,
            must_compute=invalid_string_value
        )
        print("❌ ERROR: Expected ValueError but none was raised - This should not happen")
        assert False, "Expected ValueError for invalid string value"
    except ValueError as e:
        print(f"✅ Successfully caught ValueError for invalid string: {str(e)}")
        assert "`must_compute` must be either 'all_or_none' or 'ignore_failures'" in str(e)

    # A non-string value must raise the type-specific ValueError.
    non_string_value = 123
    print(f"Testing non-string must_compute value: {non_string_value}")
    try:
        Detect(
            values_returned=["context", "generated_text"],
            api_key=self.api_key,
            config=test_config,
            must_compute=non_string_value
        )
        print("❌ ERROR: Expected ValueError but none was raised - This should not happen")
        assert False, "Expected ValueError for non-string value"
    except ValueError as e:
        print(f"✅ Successfully caught ValueError for non-string: {str(e)}")
        assert "`must_compute` must be a string value" in str(e)

    # Omitting must_compute must fall back to the documented default.
    # Fix: message previously said "default" instead of the actual
    # default value 'all_or_none' that the assertion below checks.
    print("Testing default must_compute value: 'all_or_none'")
    detect_default = Detect(
        values_returned=["context", "generated_text"],
        api_key=self.api_key,
        config=test_config
    )
    assert detect_default.must_compute == 'all_or_none'
    print(f"✅ Successfully validated default must_compute value: {detect_default.must_compute}")

    print("🎉 Result: must_compute validation working correctly")

def test_must_compute_with_actual_service(self):
    """Exercise the must_compute flag end-to-end against the live service.

    Runs a decorated generator under both computation strategies and
    reports (without failing the test) how the service responded.
    """
    print("\n=== Testing must_compute with actual service ===")

    # Two detectors configured so partial-failure semantics can surface.
    detector_config = {
        "hallucination": {"detector_name": "default"},
        "completeness": {"detector_name": "default"},
    }
    print(f"Test Config: {detector_config}")

    for strategy in ('all_or_none', 'ignore_failures'):
        print(f"\n--- Testing must_compute: {strategy} ---")

        detector = Detect(
            values_returned=["context", "generated_text", "user_query"],
            api_key=self.api_key,
            config=detector_config,
            must_compute=strategy,
        )

        @detector
        def generate_summary(context, query):
            # Trivial generator: echoes inputs plus a synthetic summary.
            return context, f"Summary of {context} based on query: {query}", query

        context = "Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed."
        query = "What is machine learning?"

        print(f"Input Context: {context}")
        print(f"Input Query: {query}")
        print(f"Must Compute: {strategy}")

        try:
            # Decorator appends the service result after the function's values.
            context_out, summary, query_out, result = generate_summary(context, query)

            print("✅ API Call Successful!")
            print(f"Status Code: {result.status}")
            print(f"Generated Text: {summary}")

            response = result.detect_response

            # Report per-detector scores when present on the response.
            if hasattr(response, 'hallucination'):
                hallucination = response.hallucination
                print(f"Hallucination Score: {hallucination.get('score', 'N/A')}")
                print(f"Is Hallucinated: {hallucination.get('is_hallucinated', 'N/A')}")

            if hasattr(response, 'completeness'):
                print(f"Completeness Score: {response.completeness.get('score', 'N/A')}")

            # Surface the raw response structure for debugging.
            print(f"Response Object Type: {type(response)}")
            if hasattr(response, '__dict__'):
                print(f"Response Attributes: {list(response.__dict__.keys())}")

        except Exception as exc:
            message = str(exc)
            print(f"API Call Result: {message}")
            print(f"Error Type: {type(exc).__name__}")

            # 503 is the expected all-or-none outcome when any detector is down;
            # ignore_failures should instead absorb per-detector unavailability.
            if strategy == 'all_or_none' and '503' in message:
                print("✅ Expected behavior: all_or_none returns 503 when services unavailable")
            elif strategy == 'ignore_failures':
                if '503' in message:
                    print("❌ Unexpected: ignore_failures should handle service unavailability")
                else:
                    print("✅ Expected behavior: ignore_failures handled the error appropriately")
            else:
                print(f"❌ Unexpected error for {strategy}: {message}")

    print("\n🎉 All must_compute service tests completed!")