Commit 7050038

feat: add ml.llm.Claude3TextGenerator model (#901)
* feat: add ml.llm.Claude3TextGenerator model
* add in toc.yml
* fix mypy
* add models
1 parent d2fc51a commit 7050038

5 files changed: +313 −2 lines changed

bigframes/ml/llm.py

Lines changed: 233 additions & 0 deletions
@@ -61,6 +61,17 @@
     _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT,
 )
 
+_CLAUDE_3_SONNET_ENDPOINT = "claude-3-sonnet"
+_CLAUDE_3_HAIKU_ENDPOINT = "claude-3-haiku"
+_CLAUDE_3_5_SONNET_ENDPOINT = "claude-3-5-sonnet"
+_CLAUDE_3_OPUS_ENDPOINT = "claude-3-opus"
+_CLAUDE_3_ENDPOINTS = (
+    _CLAUDE_3_SONNET_ENDPOINT,
+    _CLAUDE_3_HAIKU_ENDPOINT,
+    _CLAUDE_3_5_SONNET_ENDPOINT,
+    _CLAUDE_3_OPUS_ENDPOINT,
+)
+
 
 _ML_GENERATE_TEXT_STATUS = "ml_generate_text_status"
 _ML_EMBED_TEXT_STATUS = "ml_embed_text_status"
@@ -1020,3 +1031,225 @@ def to_gbq(self, model_name: str, replace: bool = False) -> GeminiTextGenerator:
 
         new_model = self._bqml_model.copy(model_name, replace)
         return new_model.session.read_gbq_model(model_name)
+
+
+@log_adapter.class_logger
+class Claude3TextGenerator(base.BaseEstimator):
+    """Claude3 text generator LLM model.
+
+    Go to the Google Cloud Console -> Vertex AI -> Model Garden page to enable the models before use. You must have the Consumer Procurement Entitlement Manager Identity and Access Management (IAM) role to enable the models.
+    https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-partner-models#grant-permissions
+
+    .. note::
+
+        This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
+        Service Specific Terms (https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
+        and might have limited support. For more information, see the launch stage descriptions
+        (https://cloud.google.com/products#product-launch-stages).
+
+    .. note::
+
+        The models are only available in specific regions. Check https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions for details.
+
+    Args:
+        model_name (str, default "claude-3-sonnet"):
+            The model for natural language tasks. Possible values are "claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet" and "claude-3-opus".
+            "claude-3-sonnet" is Anthropic's dependable combination of skills and speed. It is engineered to be dependable for scaled AI deployments across a variety of use cases.
+            "claude-3-haiku" is Anthropic's fastest, most compact vision and text model for near-instant responses to simple queries, meant for seamless AI experiences mimicking human interactions.
+            "claude-3-5-sonnet" is Anthropic's most powerful AI model. It maintains the speed and cost of Claude 3 Sonnet, a mid-tier model.
+            "claude-3-opus" is Anthropic's second-most powerful AI model, with strong performance on highly complex tasks.
+            https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#available-claude-models
+            Defaults to "claude-3-sonnet".
+        session (bigframes.Session or None):
+            BQ session to create the model. If None, use the global default session.
+        connection_name (str or None):
+            Connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
+            If None, use the default connection in the session context. BigQuery DataFrames will try to create the connection and attach
+            permission if the connection isn't fully set up.
+    """
+
+    def __init__(
+        self,
+        *,
+        model_name: Literal[
+            "claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet", "claude-3-opus"
+        ] = "claude-3-sonnet",
+        session: Optional[bigframes.Session] = None,
+        connection_name: Optional[str] = None,
+    ):
+        self.model_name = model_name
+        self.session = session or bpd.get_global_session()
+        self._bq_connection_manager = self.session.bqconnectionmanager
+
+        connection_name = connection_name or self.session._bq_connection
+        self.connection_name = clients.resolve_full_bq_connection_name(
+            connection_name,
+            default_project=self.session._project,
+            default_location=self.session._location,
+        )
+
+        self._bqml_model_factory = globals.bqml_model_factory()
+        self._bqml_model: core.BqmlModel = self._create_bqml_model()
+
+    def _create_bqml_model(self):
+        # Parse and create the connection if needed.
+        if not self.connection_name:
+            raise ValueError(
+                "Must provide connection_name, either in constructor or through session options."
+            )
+
+        if self._bq_connection_manager:
+            connection_name_parts = self.connection_name.split(".")
+            if len(connection_name_parts) != 3:
+                raise ValueError(
+                    f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
+                )
+            self._bq_connection_manager.create_bq_connection(
+                project_id=connection_name_parts[0],
+                location=connection_name_parts[1],
+                connection_id=connection_name_parts[2],
+                iam_role="aiplatform.user",
+            )
+
+        if self.model_name not in _CLAUDE_3_ENDPOINTS:
+            raise ValueError(
+                f"Model name {self.model_name} is not supported. We only support {', '.join(_CLAUDE_3_ENDPOINTS)}."
+            )
+
+        options = {
+            "endpoint": self.model_name,
+        }
+
+        return self._bqml_model_factory.create_remote_model(
+            session=self.session, connection_name=self.connection_name, options=options
+        )
+
+    @classmethod
+    def _from_bq(
+        cls, session: bigframes.Session, bq_model: bigquery.Model
+    ) -> Claude3TextGenerator:
+        assert bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
+        assert "remoteModelInfo" in bq_model._properties
+        assert "endpoint" in bq_model._properties["remoteModelInfo"]
+        assert "connection" in bq_model._properties["remoteModelInfo"]
+
+        # Parse the remote model endpoint.
+        bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"]
+        model_connection = bq_model._properties["remoteModelInfo"]["connection"]
+        model_endpoint = bqml_endpoint.split("/")[-1]
+
+        kwargs = utils.retrieve_params_from_bq_model(
+            cls, bq_model, _BQML_PARAMS_MAPPING
+        )
+
+        model = cls(
+            **kwargs,
+            session=session,
+            model_name=model_endpoint,
+            connection_name=model_connection,
+        )
+        model._bqml_model = core.BqmlModel(session, bq_model)
+        return model
+
+    @property
+    def _bqml_options(self) -> dict:
+        """The model options as they will be set for BQML."""
+        options = {
+            "data_split_method": "NO_SPLIT",
+        }
+        return options
+
+    def predict(
+        self,
+        X: Union[bpd.DataFrame, bpd.Series],
+        *,
+        max_output_tokens: int = 128,
+        top_k: int = 40,
+        top_p: float = 0.95,
+    ) -> bpd.DataFrame:
+        """Predict the result from input DataFrame.
+
+        Args:
+            X (bigframes.dataframe.DataFrame or bigframes.series.Series):
+                Input DataFrame or Series, which contains only one column of prompts.
+                Prompts can include preamble, questions, suggestions, instructions, or examples.
+
+            max_output_tokens (int, default 128):
+                Maximum number of tokens that can be generated in the response. Specify a lower value for shorter responses and a higher value for longer responses.
+                A token may be smaller than a word. A token is approximately four characters. 100 tokens correspond to roughly 60-80 words.
+                Default 128. Possible values are in the range [1, 4096].
+
+            top_k (int, default 40):
+                Top-k changes how the model selects tokens for output. A top-k of 1 means the selected token is the most probable among all tokens
+                in the model's vocabulary (also called greedy decoding), while a top-k of 3 means that the next token is selected from among the 3 most probable tokens (using temperature).
+                For each token selection step, the top K tokens with the highest probabilities are sampled. Then tokens are further filtered based on top_p, with the final token selected using temperature sampling.
+                Specify a lower value for less random responses and a higher value for more random responses.
+                Default 40. Possible values [1, 40].
+
+            top_p (float, default 0.95):
+                Top-p changes how the model selects tokens for output. Tokens are selected from the K most probable (see the top_k parameter) to the least probable, until the sum of their probabilities equals the top-p value.
+                For example, if tokens A, B, and C have a probability of 0.3, 0.2, and 0.1 and the top-p value is 0.5, then the model will select either A or B as the next token (using temperature)
+                and not consider C at all.
+                Specify a lower value for less random responses and a higher value for more random responses.
+                Default 0.95. Possible values [0.0, 1.0].
+
+        Returns:
+            bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
+        """
+
+        # Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
+        if max_output_tokens not in range(1, 4097):
+            raise ValueError(
+                f"max_output_tokens must be [1, 4096], but is {max_output_tokens}."
+            )
+
+        if top_k not in range(1, 41):
+            raise ValueError(f"top_k must be [1, 40], but is {top_k}.")
+
+        if top_p < 0.0 or top_p > 1.0:
+            raise ValueError(f"top_p must be [0.0, 1.0], but is {top_p}.")
+
+        (X,) = utils.convert_to_dataframe(X)
+
+        if len(X.columns) != 1:
+            raise ValueError(
+                f"Only support one column as input. {constants.FEEDBACK_LINK}"
+            )
+
+        # BQML identifies the column by name.
+        col_label = cast(blocks.Label, X.columns[0])
+        X = X.rename(columns={col_label: "prompt"})
+
+        options = {
+            "max_output_tokens": max_output_tokens,
+            "top_k": top_k,
+            "top_p": top_p,
+            "flatten_json_output": True,
+        }
+
+        df = self._bqml_model.generate_text(X, options)
+
+        if (df[_ML_GENERATE_TEXT_STATUS] != "").any():
+            warnings.warn(
+                f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.",
+                RuntimeWarning,
+            )
+
+        return df
+
+    def to_gbq(self, model_name: str, replace: bool = False) -> Claude3TextGenerator:
+        """Save the model to BigQuery.
+
+        Args:
+            model_name (str):
+                The name of the model.
+            replace (bool, default False):
+                Determine whether to replace if the model already exists. Defaults to False.
+
+        Returns:
+            Claude3TextGenerator: Saved model."""
+
+        new_model = self._bqml_model.copy(model_name, replace)
+        return new_model.session.read_gbq_model(model_name)
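
For orientation, a minimal usage sketch of the new class, assuming the Claude models are enabled in Model Garden for the project and a BigQuery Cloud Resource connection is already configured; the connection string and prompts are placeholders, not part of this commit.

import bigframes.pandas as bpd
from bigframes.ml import llm

# Hypothetical connection; format is <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
model = llm.Claude3TextGenerator(
    model_name="claude-3-haiku",
    connection_name="my-project.us.my-connection",
)

# predict() accepts exactly one column of prompts; it is renamed to "prompt" internally.
df = bpd.DataFrame({"prompt": ["What is BigQuery?", "Write a haiku about data."]})
result = model.predict(df, max_output_tokens=256, top_k=20, top_p=0.9)
print(result["ml_generate_text_llm_result"].to_pandas())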

bigframes/ml/loader.py

Lines changed: 5 additions & 0 deletions
@@ -63,6 +63,10 @@
     llm._GEMINI_PRO_ENDPOINT: llm.GeminiTextGenerator,
     llm._GEMINI_1P5_PRO_PREVIEW_ENDPOINT: llm.GeminiTextGenerator,
     llm._GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT: llm.GeminiTextGenerator,
+    llm._CLAUDE_3_HAIKU_ENDPOINT: llm.Claude3TextGenerator,
+    llm._CLAUDE_3_SONNET_ENDPOINT: llm.Claude3TextGenerator,
+    llm._CLAUDE_3_5_SONNET_ENDPOINT: llm.Claude3TextGenerator,
+    llm._CLAUDE_3_OPUS_ENDPOINT: llm.Claude3TextGenerator,
     llm._TEXT_EMBEDDING_004_ENDPOINT: llm.TextEmbeddingGenerator,
     llm._TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT: llm.TextEmbeddingGenerator,
 }
@@ -86,6 +90,7 @@ def from_bq(
         imported.XGBoostModel,
         llm.PaLM2TextGenerator,
         llm.PaLM2TextEmbeddingGenerator,
+        llm.Claude3TextGenerator,
         llm.TextEmbeddingGenerator,
         pipeline.Pipeline,
         compose.ColumnTransformer,
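
This endpoint-to-class mapping is what lets a saved Claude model round-trip through BigQuery: to_gbq() writes a remote model whose endpoint records e.g. "claude-3-haiku", and read_gbq_model() consults the mapping to reload it as a Claude3TextGenerator. A short sketch, with a placeholder model path:

import bigframes.pandas as bpd
from bigframes.ml import llm

session = bpd.get_global_session()
# "my_dataset.claude3_model" is hypothetical, created earlier via to_gbq().
reloaded = session.read_gbq_model("my_dataset.claude3_model")
assert isinstance(reloaded, llm.Claude3TextGenerator)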

docs/templates/toc.yml

Lines changed: 2 additions & 0 deletions
@@ -157,6 +157,8 @@
       uid: bigframes.ml.llm.PaLM2TextGenerator
     - name: PaLM2TextEmbeddingGenerator
       uid: bigframes.ml.llm.PaLM2TextEmbeddingGenerator
+    - name: Claude3TextGenerator
+      uid: bigframes.ml.llm.Claude3TextGenerator
     name: llm
   - items:
     - name: metrics

tests/system/conftest.py

Lines changed: 10 additions & 0 deletions
@@ -145,6 +145,16 @@ def session() -> Generator[bigframes.Session, None, None]:
     session.close()  # close generated session at cleanup time
 
 
+@pytest.fixture(scope="session")
+def session_us_east5() -> Generator[bigframes.Session, None, None]:
+    context = bigframes.BigQueryOptions(
+        location="us-east5",
+    )
+    session = bigframes.Session(context=context)
+    yield session
+    session.close()  # close generated session at cleanup time
+
+
 @pytest.fixture(scope="session")
 def session_load() -> Generator[bigframes.Session, None, None]:
     context = bigframes.BigQueryOptions(location="US", project="bigframes-load-testing")
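
The new fixture exists because some Claude models are served only in specific regions (see the TODO in test_llm.py below about claude-3-5-sonnet and claude-3-opus). A hypothetical follow-up test built on it might look like the sketch here; the connection string is a placeholder and would need to live in the same region as the session.

from bigframes.ml import llm

def test_claude3_opus_create(session_us_east5):
    # Assumes claude-3-opus is enabled and served in us-east5.
    model = llm.Claude3TextGenerator(
        model_name="claude-3-opus",
        connection_name="my-project.us-east5.my-connection",  # placeholder
        session=session_us_east5,
    )
    assert model._bqml_model is not None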

tests/system/small/ml/test_llm.py

Lines changed: 63 additions & 2 deletions
@@ -18,7 +18,7 @@
 from tests.system import utils
 
 
-def test_create_text_generator_model(
+def test_create_load_text_generator_model(
     palm2_text_generator_model, dataset_id, bq_connection
 ):
     # Model creation doesn't return error
@@ -34,7 +34,7 @@ def test_create_text_generator_model(
     assert reloaded_model.connection_name == bq_connection
 
 
-def test_create_text_generator_32k_model(
+def test_create_load_text_generator_32k_model(
     palm2_text_generator_32k_model, dataset_id, bq_connection
 ):
     # Model creation doesn't return error
@@ -405,6 +405,67 @@ def test_gemini_text_generator_predict_with_params_success(
     assert all(series.str.len() > 20)
 
 
+# TODO(garrettwu): add tests for claude3.5 sonnet and claude3 opus as they are only available in other regions.
+@pytest.mark.parametrize(
+    "model_name",
+    ("claude-3-sonnet", "claude-3-haiku"),
+)
+def test_claude3_text_generator_create_load(
+    dataset_id, model_name, session, bq_connection
+):
+    claude3_text_generator_model = llm.Claude3TextGenerator(
+        model_name=model_name, connection_name=bq_connection, session=session
+    )
+    assert claude3_text_generator_model is not None
+    assert claude3_text_generator_model._bqml_model is not None
+
+    # save, load to ensure configuration was kept
+    reloaded_model = claude3_text_generator_model.to_gbq(
+        f"{dataset_id}.temp_text_model", replace=True
+    )
+    assert f"{dataset_id}.temp_text_model" == reloaded_model._bqml_model.model_name
+    assert reloaded_model.connection_name == bq_connection
+    assert reloaded_model.model_name == model_name
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    ("claude-3-sonnet", "claude-3-haiku"),
+)
+@pytest.mark.flaky(retries=2)
+def test_claude3_text_generator_predict_default_params_success(
+    llm_text_df, model_name, session, bq_connection
+):
+    claude3_text_generator_model = llm.Claude3TextGenerator(
+        model_name=model_name, connection_name=bq_connection, session=session
+    )
+    df = claude3_text_generator_model.predict(llm_text_df).to_pandas()
+    assert df.shape == (3, 3)
+    assert "ml_generate_text_llm_result" in df.columns
+    series = df["ml_generate_text_llm_result"]
+    assert all(series.str.len() > 20)
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    ("claude-3-sonnet", "claude-3-haiku"),
+)
+@pytest.mark.flaky(retries=2)
+def test_claude3_text_generator_predict_with_params_success(
+    llm_text_df, model_name, session, bq_connection
+):
+    claude3_text_generator_model = llm.Claude3TextGenerator(
+        model_name=model_name, connection_name=bq_connection, session=session
+    )
+    df = claude3_text_generator_model.predict(
+        llm_text_df, max_output_tokens=100, top_k=20, top_p=0.5
+    ).to_pandas()
+    assert df.shape == (3, 3)
+    assert "ml_generate_text_llm_result" in df.columns
+    series = df["ml_generate_text_llm_result"]
+    assert all(series.str.len() > 20)
+
+
 @pytest.mark.flaky(retries=2)
 def test_llm_palm_score(llm_fine_tune_df_default_index):
     model = llm.PaLM2TextGenerator(model_name="text-bison")
