diff --git a/vertexai/_model_garden/_model_garden_models.py b/vertexai/_model_garden/_model_garden_models.py index ce587d7184..30f71398e6 100644 --- a/vertexai/_model_garden/_model_garden_models.py +++ b/vertexai/_model_garden/_model_garden_models.py @@ -31,7 +31,8 @@ _SUPPORTED_PUBLISHERS = ["google"] _SHORT_MODEL_ID_TO_TUNING_PIPELINE_MAP = { - "text-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0" + "text-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0", + "code-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v3.0.0", } _SDK_PRIVATE_PREVIEW_LAUNCH_STAGE = frozenset( diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index c1e7524968..bcdc410b8f 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -54,7 +54,9 @@ def _get_model_id_from_tuning_model_id(tuning_model_id: str) -> str: return tuning_model_id.replace( "text-bison-", "publishers/google/models/text-bison@" ) - raise ValueError(f"Unsupported tuning model ID {tuning_model_id}") + if "/" not in tuning_model_id: + return "publishers/google/models/" + tuning_model_id + return tuning_model_id class _LanguageModel(_model_garden_models._ModelGardenModel): @@ -1007,6 +1009,10 @@ def predict( ) +class _PreviewCodeGenerationModel(CodeGenerationModel, _TunableModelMixin): + _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE + + ###### Model tuning # Currently, tuning can only work in this location _TUNING_LOCATIONS = ("europe-west4", "us-central1") diff --git a/vertexai/preview/language_models.py b/vertexai/preview/language_models.py index aee71e2aaa..057d73fdcd 100644 --- a/vertexai/preview/language_models.py +++ b/vertexai/preview/language_models.py @@ -16,6 +16,7 @@ from vertexai.language_models._language_models import ( _PreviewChatModel, + _PreviewCodeGenerationModel, _PreviewTextEmbeddingModel, _PreviewTextGenerationModel, ChatMessage, @@ -23,13 +24,13 @@ ChatSession, CodeChatModel, CodeChatSession, - CodeGenerationModel, InputOutputTextPair, TextEmbedding, TextGenerationResponse, ) ChatModel = _PreviewChatModel +CodeGenerationModel = _PreviewCodeGenerationModel TextGenerationModel = _PreviewTextGenerationModel TextEmbeddingModel = _PreviewTextEmbeddingModel