|
61 | 61 | _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT,
|
62 | 62 | )
|
63 | 63 |
|
| 64 | +_CLAUDE_3_SONNET_ENDPOINT = "claude-3-sonnet" |
| 65 | +_CLAUDE_3_HAIKU_ENDPOINT = "claude-3-haiku" |
| 66 | +_CLAUDE_3_5_SONNET_ENDPOINT = "claude-3-5-sonnet" |
| 67 | +_CLAUDE_3_OPUS_ENDPOINT = "claude-3-opus" |
| 68 | +_CLAUDE_3_ENDPOINTS = ( |
| 69 | + _CLAUDE_3_SONNET_ENDPOINT, |
| 70 | + _CLAUDE_3_HAIKU_ENDPOINT, |
| 71 | + _CLAUDE_3_5_SONNET_ENDPOINT, |
| 72 | + _CLAUDE_3_OPUS_ENDPOINT, |
| 73 | +) |
| 74 | + |
64 | 75 |
|
65 | 76 | _ML_GENERATE_TEXT_STATUS = "ml_generate_text_status"
|
66 | 77 | _ML_EMBED_TEXT_STATUS = "ml_embed_text_status"
|
@@ -1020,3 +1031,225 @@ def to_gbq(self, model_name: str, replace: bool = False) -> GeminiTextGenerator:
|
1020 | 1031 |
|
1021 | 1032 | new_model = self._bqml_model.copy(model_name, replace)
|
1022 | 1033 | return new_model.session.read_gbq_model(model_name)
|
| 1034 | + |
| 1035 | + |
| 1036 | +@log_adapter.class_logger |
| 1037 | +class Claude3TextGenerator(base.BaseEstimator): |
| 1038 | + """Claude3 text generator LLM model. |
| 1039 | +
|
| 1040 | + Go to the Google Cloud Console -> Vertex AI -> Model Garden page to enable the models before use. You must have the Consumer Procurement Entitlement Manager Identity and Access Management (IAM) role to enable the models. |
| 1041 | + https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-partner-models#grant-permissions |
| 1042 | +
|
| 1043 | + .. note:: |
| 1044 | +
|
| 1045 | + This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the |
| 1046 | + Service Specific Terms (https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is" |
| 1047 | + and might have limited support. For more information, see the launch stage descriptions |
| 1048 | + (https://cloud.google.com/products#product-launch-stages). |
| 1049 | +
|
| 1050 | +
|
| 1051 | + .. note:: |
| 1052 | +
|
| 1053 | + The models are only available in specific regions. Check https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions for details. |
| 1054 | +
|
| 1055 | + Args: |
| 1056 | + model_name (str, default "claude-3-sonnet"): |
| 1057 | + The model for natural language tasks. Possible values are "claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet" and "claude-3-opus". |
| 1058 | + "claude-3-sonnet" is Anthropic's dependable combination of skills and speed, engineered for scaled AI deployments across a variety of use cases. |
| 1059 | + "claude-3-haiku" is Anthropic's fastest, most compact vision and text model for near-instant responses to simple queries, meant for seamless AI experiences mimicking human interactions. |
| 1060 | + "claude-3-5-sonnet" is Anthropic's most powerful AI model, while maintaining the speed and cost of the mid-tier Claude 3 Sonnet. |
| 1061 | + "claude-3-opus" is Anthropic's second-most powerful AI model, with strong performance on highly complex tasks. |
| 1062 | + See https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#available-claude-models for details. |
| 1063 | + Defaults to "claude-3-sonnet". |
| 1064 | + session (bigframes.Session or None): |
| 1065 | + BQ session to create the model. If None, use the global default session. |
| 1066 | + connection_name (str or None): |
| 1067 | + Connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>. |
| 1068 | + If None, use default connection in session context. BigQuery DataFrame will try to create the connection and attach |
| 1069 | + permission if the connection isn't fully set up. |
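|  | + |
|  | + **Examples:** |
|  | + |
|  | + A minimal usage sketch; the connection name below is hypothetical and assumes the Claude models are already enabled in Model Garden: |
|  | + |
|  | + >>> from bigframes.ml import llm |
|  | + >>> model = llm.Claude3TextGenerator( |
|  | + ...     model_name="claude-3-haiku", |
|  | + ...     connection_name="my-project.us.my-connection", |
|  | + ... ) |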
| 1070 | + """ |
| 1071 | + |
| 1072 | + def __init__( |
| 1073 | + self, |
| 1074 | + *, |
| 1075 | + model_name: Literal[ |
| 1076 | + "claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet", "claude-3-opus" |
| 1077 | + ] = "claude-3-sonnet", |
| 1078 | + session: Optional[bigframes.Session] = None, |
| 1079 | + connection_name: Optional[str] = None, |
| 1080 | + ): |
| 1081 | + self.model_name = model_name |
| 1082 | + self.session = session or bpd.get_global_session() |
| 1083 | + self._bq_connection_manager = self.session.bqconnectionmanager |
| 1084 | + |
| 1085 | + connection_name = connection_name or self.session._bq_connection |
| 1086 | + self.connection_name = clients.resolve_full_bq_connection_name( |
| 1087 | + connection_name, |
| 1088 | + default_project=self.session._project, |
| 1089 | + default_location=self.session._location, |
| 1090 | + ) |
| 1091 | + |
| 1092 | + self._bqml_model_factory = globals.bqml_model_factory() |
| 1093 | + self._bqml_model: core.BqmlModel = self._create_bqml_model() |
| 1094 | + |
| 1095 | + def _create_bqml_model(self): |
| 1096 | + # Parse and create connection if needed. |
| 1097 | + if not self.connection_name: |
| 1098 | + raise ValueError( |
| 1099 | + "Must provide connection_name, either in constructor or through session options." |
| 1100 | + ) |
| 1101 | + |
| 1102 | + if self._bq_connection_manager: |
| 1103 | + connection_name_parts = self.connection_name.split(".") |
| 1104 | + if len(connection_name_parts) != 3: |
| 1105 | + raise ValueError( |
| 1106 | + f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}." |
| 1107 | + ) |
| 1108 | + self._bq_connection_manager.create_bq_connection( |
| 1109 | + project_id=connection_name_parts[0], |
| 1110 | + location=connection_name_parts[1], |
| 1111 | + connection_id=connection_name_parts[2], |
| 1112 | + iam_role="aiplatform.user", |
| 1113 | + ) |
| 1114 | + |
| 1115 | + if self.model_name not in _CLAUDE_3_ENDPOINTS: |
| 1116 | + raise ValueError( |
| 1117 | + f"Model name {self.model_name} is not supported. We only support {', '.join(_CLAUDE_3_ENDPOINTS)}." |
| 1118 | + ) |
| 1119 | + |
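|  | + # The "endpoint" option tells BQML which Claude model the remote model calls. |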
| 1120 | + options = { |
| 1121 | + "endpoint": self.model_name, |
| 1122 | + } |
| 1123 | + |
| 1124 | + return self._bqml_model_factory.create_remote_model( |
| 1125 | + session=self.session, connection_name=self.connection_name, options=options |
| 1126 | + ) |
| 1127 | + |
| 1128 | + @classmethod |
| 1129 | + def _from_bq( |
| 1130 | + cls, session: bigframes.Session, bq_model: bigquery.Model |
| 1131 | + ) -> Claude3TextGenerator: |
| 1132 | + assert bq_model.model_type == "MODEL_TYPE_UNSPECIFIED" |
| 1133 | + assert "remoteModelInfo" in bq_model._properties |
| 1134 | + assert "endpoint" in bq_model._properties["remoteModelInfo"] |
| 1135 | + assert "connection" in bq_model._properties["remoteModelInfo"] |
| 1136 | + |
| 1137 | + # Parse the remote model endpoint |
| 1138 | + bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"] |
| 1139 | + model_connection = bq_model._properties["remoteModelInfo"]["connection"] |
| 1140 | + model_endpoint = bqml_endpoint.split("/")[-1] |
| 1141 | + |
| 1142 | + kwargs = utils.retrieve_params_from_bq_model( |
| 1143 | + cls, bq_model, _BQML_PARAMS_MAPPING |
| 1144 | + ) |
| 1145 | + |
| 1146 | + model = cls( |
| 1147 | + **kwargs, |
| 1148 | + session=session, |
| 1149 | + model_name=model_endpoint, |
| 1150 | + connection_name=model_connection, |
| 1151 | + ) |
| 1152 | + model._bqml_model = core.BqmlModel(session, bq_model) |
| 1153 | + return model |
| 1154 | + |
| 1155 | + @property |
| 1156 | + def _bqml_options(self) -> dict: |
| 1157 | + """The model options as they will be set for BQML""" |
| 1158 | + options = { |
| 1159 | + "data_split_method": "NO_SPLIT", |
| 1160 | + } |
| 1161 | + return options |
| 1162 | + |
| 1163 | + def predict( |
| 1164 | + self, |
| 1165 | + X: Union[bpd.DataFrame, bpd.Series], |
| 1166 | + *, |
| 1167 | + max_output_tokens: int = 128, |
| 1168 | + top_k: int = 40, |
| 1169 | + top_p: float = 0.95, |
| 1170 | + ) -> bpd.DataFrame: |
| 1171 | + """Predict the result from input DataFrame. |
| 1172 | +
|
| 1173 | + Args: |
| 1174 | + X (bigframes.dataframe.DataFrame or bigframes.series.Series): |
| 1175 | + Input DataFrame or Series, which contains only one column of prompts. |
| 1176 | + Prompts can include preamble, questions, suggestions, instructions, or examples. |
| 1177 | +
|
| 1178 | + max_output_tokens (int, default 128): |
| 1179 | + Maximum number of tokens that can be generated in the response. Specify a lower value for shorter responses and a higher value for longer responses. |
| 1180 | + A token may be smaller than a word. A token is approximately four characters. 100 tokens correspond to roughly 60-80 words. |
| 1181 | + Default 128. Possible values are in the range [1, 4096]. |
| 1182 | +
|
| 1183 | + top_k (int, default 40): |
| 1184 | + Top-k changes how the model selects tokens for output. A top-k of 1 means the selected token is the most probable among all tokens |
| 1185 | + in the model's vocabulary (also called greedy decoding), while a top-k of 3 means that the next token is selected from among the 3 most probable tokens (using temperature). |
| 1186 | + For each token selection step, the top K tokens with the highest probabilities are sampled. Then tokens are further filtered based on topP with the final token selected using temperature sampling. |
| 1187 | + Specify a lower value for less random responses and a higher value for more random responses. |
| 1188 | + Default 40. Possible values [1, 40]. |
| 1189 | +
|
| 1190 | + top_p (float, default 0.95): |
| 1191 | + Top-p changes how the model selects tokens for output. Tokens are selected from the most probable to the least (limited by top_k; see the top_k parameter) until the sum of their probabilities equals the top-p value. |
| 1192 | + For example, if tokens A, B, and C have a probability of 0.3, 0.2, and 0.1 and the top-p value is 0.5, then the model will select either A or B as the next token (using temperature) |
| 1193 | + and not consider C at all. |
| 1194 | + Specify a lower value for less random responses and a higher value for more random responses. |
| 1195 | + Default 0.95. Possible values [0.0, 1.0]. |
| 1196 | +
|
| 1197 | +
|
| 1198 | + Returns: |
| 1199 | + bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values. |
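|  | + |
|  | + Example (an illustrative sketch; assumes ``model`` is a constructed Claude3TextGenerator and the prompt text is made up): |
|  | + |
|  | + >>> import bigframes.pandas as bpd |
|  | + >>> df = bpd.DataFrame({"prompt": ["What is BigQuery?"]}) |
|  | + >>> result = model.predict(df, max_output_tokens=256) |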
| 1200 | + """ |
| 1201 | + |
| 1202 | + # Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models |
| 1203 | + if max_output_tokens not in range(1, 4097): |
| 1204 | + raise ValueError( |
| 1205 | + f"max_output_token must be [1, 4096], but is {max_output_tokens}." |
| 1206 | + ) |
| 1207 | + |
| 1208 | + if top_k not in range(1, 41): |
| 1209 | + raise ValueError(f"top_k must be [1, 40], but is {top_k}.") |
| 1210 | + |
| 1211 | + if top_p < 0.0 or top_p > 1.0: |
| 1212 | + raise ValueError(f"top_p must be [0.0, 1.0], but is {top_p}.") |
| 1213 | + |
| 1214 | + (X,) = utils.convert_to_dataframe(X) |
| 1215 | + |
| 1216 | + if len(X.columns) != 1: |
| 1217 | + raise ValueError( |
| 1218 | + f"Only support one column as input. {constants.FEEDBACK_LINK}" |
| 1219 | + ) |
| 1220 | + |
| 1221 | + # BQML identifies the column by name |
| 1222 | + col_label = cast(blocks.Label, X.columns[0]) |
| 1223 | + X = X.rename(columns={col_label: "prompt"}) |
| 1224 | + |
| 1225 | + options = { |
| 1226 | + "max_output_tokens": max_output_tokens, |
| 1227 | + "top_k": top_k, |
| 1228 | + "top_p": top_p, |
| 1229 | + "flatten_json_output": True, |
| 1230 | + } |
| 1231 | + |
| 1232 | + df = self._bqml_model.generate_text(X, options) |
| 1233 | + |
| 1234 | + if (df[_ML_GENERATE_TEXT_STATUS] != "").any(): |
| 1235 | + warnings.warn( |
| 1236 | + f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.", |
| 1237 | + RuntimeWarning, |
| 1238 | + ) |
| 1239 | + |
| 1240 | + return df |
| 1241 | + |
| 1242 | + def to_gbq(self, model_name: str, replace: bool = False) -> Claude3TextGenerator: |
| 1243 | + """Save the model to BigQuery. |
| 1244 | +
|
| 1245 | + Args: |
| 1246 | + model_name (str): |
| 1247 | + The name of the model. |
| 1248 | + replace (bool, default False): |
| 1249 | + Whether to replace the model if it already exists. Defaults to False. |
| 1250 | +
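|  | + Example (illustrative; the model path is hypothetical): |
|  | + |
|  | + >>> saved = model.to_gbq("my-project.my_dataset.claude_model", replace=True) |
|  | + |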
|
| 1251 | + Returns: |
| 1252 | + Claude3TextGenerator: Saved model.""" |
| 1253 | + |
| 1254 | + new_model = self._bqml_model.copy(model_name, replace) |
| 1255 | + return new_model.session.read_gbq_model(model_name) |