|
29 | 29 | import pandas as pd
|
30 | 30 | from tqdm import tqdm
|
31 | 31 |
|
| 32 | +from . import _evals_constant |
32 | 33 | from . import _evals_data_converters
|
33 | 34 | from . import _evals_metric_handlers
|
34 | 35 | from . import _evals_utils
|
@@ -370,6 +371,14 @@ def _run_litellm_inference(
|
370 | 371 | return responses
|
371 | 372 |
|
372 | 373 |
|
| 374 | +def _is_litellm_vertex_maas_model(model: str) -> bool: |
| 375 | + """Checks if the model is a Vertex MAAS model to be handled by LiteLLM.""" |
| 376 | + return any( |
| 377 | + model.startswith(prefix) |
| 378 | + for prefix in _evals_constant.SUPPORTED_VERTEX_MAAS_MODEL_PREFIXES |
| 379 | + ) |
| 380 | + |
| 381 | + |
373 | 382 | def _is_litellm_model(model: str) -> bool:
|
374 | 383 | """Checks if the model name corresponds to a valid LiteLLM model name."""
|
375 | 384 | return model in litellm.utils.get_valid_models(model)
|
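For readers skimming the hunk above: the new `_is_litellm_vertex_maas_model` helper is a plain prefix match against a constant tuple. Below is a self-contained sketch of that behavior; the prefix values are hypothetical placeholders, since the real `_evals_constant.SUPPORTED_VERTEX_MAAS_MODEL_PREFIXES` is not shown in this diff.

```python
# Sketch of the prefix check added in this hunk. The prefixes below are
# hypothetical placeholders; the real values live in
# _evals_constant.SUPPORTED_VERTEX_MAAS_MODEL_PREFIXES (not shown in this diff).
SUPPORTED_VERTEX_MAAS_MODEL_PREFIXES = ("meta/", "mistralai/")  # hypothetical


def _is_litellm_vertex_maas_model(model: str) -> bool:
    """Checks if the model is a Vertex MAAS model to be handled by LiteLLM."""
    return any(
        model.startswith(prefix)
        for prefix in SUPPORTED_VERTEX_MAAS_MODEL_PREFIXES
    )


print(_is_litellm_vertex_maas_model("meta/llama-3.1-405b-instruct-maas"))  # True with these placeholder prefixes
print(_is_litellm_vertex_maas_model("gemini-1.5-pro"))  # False
```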
@@ -431,47 +440,64 @@ def _run_inference_internal(
|
431 | 440 | }
|
432 | 441 | processed_responses.append(json.dumps(error_payload))
|
433 | 442 | responses = processed_responses
|
| 443 | + elif callable(model): |
| 444 | + logger.info("Running inference with custom callable function.") |
| 445 | + custom_responses_raw = _run_custom_inference( |
| 446 | + model_fn=model, prompt_dataset=prompt_dataset |
| 447 | + ) |
| 448 | + processed_custom_responses = [] |
| 449 | + for resp_item in custom_responses_raw: |
| 450 | + if isinstance(resp_item, str): |
| 451 | + processed_custom_responses.append(resp_item) |
| 452 | + elif isinstance(resp_item, dict) and "error" in resp_item: |
| 453 | + processed_custom_responses.append(json.dumps(resp_item)) |
| 454 | + else: |
| 455 | + try: |
| 456 | + processed_custom_responses.append(json.dumps(resp_item)) |
| 457 | + except TypeError: |
| 458 | + processed_custom_responses.append(str(resp_item)) |
| 459 | + responses = processed_custom_responses |
434 | 460 | elif isinstance(model, str):
|
435 | 461 | if litellm is None:
|
436 | 462 | raise ImportError(
|
437 | | -                "The 'litellm' library is required to use third-party models." |
| 463 | + "The 'litellm' library is required to use this model." |
438 | 464 | " Please install it using 'pip install"
|
439 | 465 | " google-cloud-aiplatform[evaluation]'."
|
440 | 466 | )
|
441 | | -        if _is_litellm_model(model): |
442 | | -            logger.info("Running inference with LiteLLM for model: %s", model) |
443 | | -            raw_responses = _run_litellm_inference( # type: ignore[assignment] |
444 | | -                model=model, prompt_dataset=prompt_dataset |
| 467 | + |
| 468 | + processed_model_id = model |
| 469 | + if model.startswith("vertex_ai/"): |
| 470 | + # Already correctly prefixed for LiteLLM's Vertex AI provider |
| 471 | + pass |
| 472 | + elif _is_litellm_vertex_maas_model(model): |
| 473 | + processed_model_id = f"vertex_ai/{model}" |
| 474 | + logger.info( |
| 475 | + "Detected Vertex AI Model Garden managed MaaS model. " |
| 476 | + "Using LiteLLM ID: %s", |
| 477 | + processed_model_id, |
445 | 478 | )
|
446 | | -            responses = [json.dumps(resp) for resp in raw_responses] |
| 479 | + elif _is_litellm_model(model): |
| 480 | + # Other LiteLLM supported model |
| 481 | + logger.info("Running inference with LiteLLM for model: %s", model) |
447 | 482 | else:
|
| 483 | + # Unsupported model string |
448 | 484 | raise TypeError(
|
449 | 485 | f"Unsupported string model name: {model}. Expecting a Gemini model"
|
450 | | -                " name (e.g., 'gemini-2.5-pro', 'projects/.../models/...') or a" |
| 486 | + " name (e.g., 'gemini-1.5-pro', 'projects/.../models/...') or a" |
451 | 487 | " LiteLLM supported model name (e.g., 'openai/gpt-4o')."
|
452 | 488 | " If using a third-party model via LiteLLM, ensure the"
|
453 | 489 | " necessary environment variables are set (e.g., for OpenAI:"
|
454 | 490 | " `os.environ['OPENAI_API_KEY'] = 'Your API Key'`). See"
|
455 | 491 | " LiteLLM documentation for details:"
|
456 | 492 | " https://docs.litellm.ai/docs/set_keys#environment-variables"
|
457 | 493 | )
|
458 | | -    elif callable(model): |
459 | | -        logger.info("Running inference with custom callable function.") |
460 | | -        custom_responses_raw = _run_custom_inference( |
461 | | -            model_fn=model, prompt_dataset=prompt_dataset |
| 494 | + |
| 495 | + logger.info("Running inference via LiteLLM for model: %s", processed_model_id) |
| 496 | + raw_responses = _run_litellm_inference( |
| 497 | + model=processed_model_id, prompt_dataset=prompt_dataset |
462 | 498 | )
|
463 | | -        processed_custom_responses = [] |
464 | | -        for resp_item in custom_responses_raw: |
465 | | -            if isinstance(resp_item, str): |
466 | | -                processed_custom_responses.append(resp_item) |
467 | | -            elif isinstance(resp_item, dict) and "error" in resp_item: |
468 | | -                processed_custom_responses.append(json.dumps(resp_item)) |
469 | | -            else: |
470 | | -                try: |
471 | | -                    processed_custom_responses.append(json.dumps(resp_item)) |
472 | | -                except TypeError: |
473 | | -                    processed_custom_responses.append(str(resp_item)) |
474 | | -        responses = processed_custom_responses |
| 499 | + responses = [json.dumps(resp) for resp in raw_responses] |
| 500 | + |
475 | 501 | else:
|
476 | 502 | raise TypeError(
|
477 | 503 | f"Unsupported model type: {type(model)}. Expecting string (model"
|
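The relocated `elif callable(model)` branch normalizes whatever the custom inference function returns into strings before scoring. A minimal, standalone sketch of that serialization rule (the function name here is illustrative, not the SDK's):

```python
import json
from typing import Any, List


def serialize_custom_responses(raw_responses: List[Any]) -> List[str]:
    """Illustrative sketch of the serialization rule in the callable branch."""
    processed: List[str] = []
    for resp_item in raw_responses:
        if isinstance(resp_item, str):
            processed.append(resp_item)  # plain strings pass through unchanged
        elif isinstance(resp_item, dict) and "error" in resp_item:
            processed.append(json.dumps(resp_item))  # error payloads kept as JSON
        else:
            try:
                processed.append(json.dumps(resp_item))  # JSON-serializable objects
            except TypeError:
                processed.append(str(resp_item))  # last-resort string fallback
    return processed


print(serialize_custom_responses(["ok", {"error": "timeout"}, {"text": "hi"}, object()]))
```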
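Taken together, the string branch now normalizes the model name to a LiteLLM model ID before a single `_run_litellm_inference` call: `vertex_ai/`-prefixed names pass through, Model Garden MaaS names get the `vertex_ai/` prefix added, and anything else must pass the generic LiteLLM validity check or raise `TypeError`. A sketch of that decision order, with the two helpers stubbed out using hypothetical values (not the SDK's code):

```python
# Hypothetical stand-ins for the SDK helpers referenced in the diff.
SUPPORTED_VERTEX_MAAS_MODEL_PREFIXES = ("meta/", "mistralai/")  # hypothetical


def _is_litellm_vertex_maas_model(model: str) -> bool:
    return any(model.startswith(p) for p in SUPPORTED_VERTEX_MAAS_MODEL_PREFIXES)


def _is_litellm_model(model: str) -> bool:
    # Stub: the real check calls litellm.utils.get_valid_models(model).
    return model.startswith("openai/")


def resolve_litellm_model_id(model: str) -> str:
    """Mirrors the normalization order added in the hunk above (sketch only)."""
    if model.startswith("vertex_ai/"):
        return model  # already prefixed for LiteLLM's Vertex AI provider
    if _is_litellm_vertex_maas_model(model):
        return f"vertex_ai/{model}"  # route MaaS models through the Vertex AI provider
    if _is_litellm_model(model):
        return model  # any other LiteLLM-supported model ID
    raise TypeError(f"Unsupported string model name: {model}")


print(resolve_litellm_model_id("meta/llama-3.1-405b-instruct-maas"))  # -> vertex_ai/meta/llama-...
print(resolve_litellm_model_id("openai/gpt-4o"))  # -> openai/gpt-4o
```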