Skip to content

Commit

Permalink
feat: LLM - Support tuning for the code-bison model (preview)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 550751052
  • Loading branch information
Ark-kun authored and copybara-github committed Jul 25, 2023
1 parent 75eb777 commit e4b23a2
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
3 changes: 2 additions & 1 deletion vertexai/_model_garden/_model_garden_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
_SUPPORTED_PUBLISHERS = ["google"]

_SHORT_MODEL_ID_TO_TUNING_PIPELINE_MAP = {
"text-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0"
"text-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0",
"code-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v3.0.0",
}

_SDK_PRIVATE_PREVIEW_LAUNCH_STAGE = frozenset(
Expand Down
8 changes: 7 additions & 1 deletion vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def _get_model_id_from_tuning_model_id(tuning_model_id: str) -> str:
return tuning_model_id.replace(
"text-bison-", "publishers/google/models/text-bison@"
)
raise ValueError(f"Unsupported tuning model ID {tuning_model_id}")
if "/" not in tuning_model_id:
return "publishers/google/models/" + tuning_model_id
return tuning_model_id


class _LanguageModel(_model_garden_models._ModelGardenModel):
Expand Down Expand Up @@ -1007,6 +1009,10 @@ def predict(
)


class _PreviewCodeGenerationModel(CodeGenerationModel, _TunableModelMixin):
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE


###### Model tuning
# Currently, tuning can only work in this location
_TUNING_LOCATIONS = ("europe-west4", "us-central1")
Expand Down
3 changes: 2 additions & 1 deletion vertexai/preview/language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@

from vertexai.language_models._language_models import (
_PreviewChatModel,
_PreviewCodeGenerationModel,
_PreviewTextEmbeddingModel,
_PreviewTextGenerationModel,
ChatMessage,
ChatModel,
ChatSession,
CodeChatModel,
CodeChatSession,
CodeGenerationModel,
InputOutputTextPair,
TextEmbedding,
TextGenerationResponse,
)

ChatModel = _PreviewChatModel
CodeGenerationModel = _PreviewCodeGenerationModel
TextGenerationModel = _PreviewTextGenerationModel
TextEmbeddingModel = _PreviewTextEmbeddingModel

Expand Down

0 comments on commit e4b23a2

Please sign in to comment.