Skip to content

Commit

Permalink
fix: GenAI - Fixed from_dict methods
Browse files Browse the repository at this point in the history
The parsing of dictionaries was standardized and improved. Parsing dictionaries via `json_format.ParseDict` allows automatically handling renamed fields like `type_`, `format_`, `enum_` and also camelCase keys like `minItems` that are often used in JSON Schema documents. The `json_format.ParseDict` method succeeds in such cases while the Proto Message constructors fail.

However, there is an issue with enums which makes full migration to `json_format.ParseDict` non-trivial. Protobuf's `ParseDict` treats enum names as case-sensitive. This is mostly problematic for Schema messages, where Python protos use uppercase enum names like `"OBJECT"` while JSON Schemas usually use lowercase types like `"object"`. I've created a fix for the protobuf library, but it will take a while before it's released everywhere. So, enums need to be fixed. But such enums can occur in multiple places in a deeply nested dictionary structure. We need to carefully fix all enum casing issues in nested dicts before we can use `ParseDict`.

PiperOrigin-RevId: 680480143
  • Loading branch information
Ark-kun authored and copybara-github committed Sep 30, 2024
1 parent 29dec74 commit 3090812
Showing 1 changed file with 56 additions and 21 deletions.
77 changes: 56 additions & 21 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
Literal,
Optional,
Sequence,
Type,
TypeVar,
Union,
overload,
TYPE_CHECKING,
Expand Down Expand Up @@ -66,6 +68,10 @@
except ImportError:
PIL_Image = None


T = TypeVar("T")


# Re-exporting some GAPIC types

# GAPIC types used in request
Expand Down Expand Up @@ -1627,8 +1633,9 @@ def __init__(
if response_schema is None:
raw_schema = None
else:
gapic_schema_dict = _convert_schema_dict_to_gapic(response_schema)
raw_schema = aiplatform_types.Schema(gapic_schema_dict)
raw_schema = FunctionDeclaration(
name="tmp", parameters=response_schema
)._raw_function_declaration.parameters
self._raw_generation_config = gapic_content_types.GenerationConfig(
temperature=temperature,
top_p=top_p,
Expand Down Expand Up @@ -1660,8 +1667,12 @@ def _from_gapic(

@classmethod
def from_dict(cls, generation_config_dict: Dict[str, Any]) -> "GenerationConfig":
raw_generation_config = gapic_content_types.GenerationConfig(
generation_config_dict
generation_config_dict = copy.deepcopy(generation_config_dict)
response_schema = generation_config_dict.get("response_schema")
if response_schema:
_fix_schema_dict_for_gapic_in_place(response_schema)
raw_generation_config = _dict_to_proto(
gapic_content_types.GenerationConfig, generation_config_dict
)
return cls._from_gapic(raw_generation_config=raw_generation_config)

Expand Down Expand Up @@ -1872,12 +1883,11 @@ def _from_gapic(
@classmethod
def from_dict(cls, tool_dict: Dict[str, Any]) -> "Tool":
tool_dict = copy.deepcopy(tool_dict)
function_declarations = tool_dict["function_declarations"]
for function_declaration in function_declarations:
function_declaration["parameters"] = _convert_schema_dict_to_gapic(
function_declaration["parameters"]
)
raw_tool = gapic_tool_types.Tool(tool_dict)
for function_declaration in tool_dict.get("function_declarations") or []:
parameters = function_declaration.get("parameters")
if parameters:
_fix_schema_dict_for_gapic_in_place(parameters)
raw_tool = _dict_to_proto(aiplatform_types.Tool, tool_dict)
return cls._from_gapic(raw_tool=raw_tool)

def to_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -2035,8 +2045,9 @@ def __init__(
description: Description and purpose of the function.
Model uses it to decide how and whether to call the function.
"""
gapic_schema_dict = _convert_schema_dict_to_gapic(parameters)
raw_schema = aiplatform_types.Schema(gapic_schema_dict)
parameters = copy.deepcopy(parameters)
_fix_schema_dict_for_gapic_in_place(parameters)
raw_schema = _dict_to_proto(aiplatform_types.Schema, parameters)
self._raw_function_declaration = gapic_tool_types.FunctionDeclaration(
name=name, description=description, parameters=raw_schema
)
Expand All @@ -2052,6 +2063,7 @@ def __repr__(self) -> str:
return self._raw_function_declaration.__repr__()


# TODO: Remove this function once Reasoning Engines moves away from it.
def _convert_schema_dict_to_gapic(schema_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Converts a JsonSchema to a dict that the GAPIC Schema class accepts."""
gapic_schema_dict = copy.copy(schema_dict)
Expand All @@ -2070,6 +2082,20 @@ def _convert_schema_dict_to_gapic(schema_dict: Dict[str, Any]) -> Dict[str, Any]
return gapic_schema_dict


def _fix_schema_dict_for_gapic_in_place(schema_dict: Dict[str, Any]) -> None:
    """Upper-cases JSON Schema type names so the Schema proto class accepts them.

    Protobuf's `ParseDict` treats enum names as case-sensitive, while JSON
    Schemas usually spell types in lowercase (e.g. "object"). This walks the
    schema and rewrites every string "type" value to uppercase.

    Args:
        schema_dict: JSON-Schema-style dictionary. It is modified in place,
            including the nested schemas found under "items" and
            "properties".
    """
    # Sub-schemas are not required to carry a "type" key, and a non-string
    # value (e.g. an integer enum value) has no casing to fix, so only
    # string values are rewritten.
    schema_type = schema_dict.get("type")
    if isinstance(schema_type, str):
        schema_dict["type"] = schema_type.upper()

    items_schema = schema_dict.get("items")
    if items_schema:
        _fix_schema_dict_for_gapic_in_place(items_schema)

    properties = schema_dict.get("properties")
    if properties:
        for property_schema in properties.values():
            _fix_schema_dict_for_gapic_in_place(property_schema)


class CallableFunctionDeclaration(FunctionDeclaration):
"""A function declaration plus a function."""

Expand Down Expand Up @@ -2139,8 +2165,9 @@ def _from_gapic(

@classmethod
def from_dict(cls, response_dict: Dict[str, Any]) -> "GenerationResponse":
raw_response = gapic_prediction_service_types.GenerateContentResponse()
json_format.ParseDict(response_dict, raw_response._pb)
raw_response = _dict_to_proto(
gapic_prediction_service_types.GenerateContentResponse, response_dict
)
return cls._from_gapic(raw_response=raw_response)

def to_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -2209,8 +2236,7 @@ def _from_gapic(cls, raw_candidate: gapic_content_types.Candidate) -> "Candidate

@classmethod
def from_dict(cls, candidate_dict: Dict[str, Any]) -> "Candidate":
raw_candidate = gapic_content_types.Candidate()
json_format.ParseDict(candidate_dict, raw_candidate._pb)
raw_candidate = _dict_to_proto(gapic_content_types.Candidate, candidate_dict)
return cls._from_gapic(raw_candidate=raw_candidate)

def to_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -2310,8 +2336,7 @@ def _from_gapic(cls, raw_content: gapic_content_types.Content) -> "Content":

@classmethod
def from_dict(cls, content_dict: Dict[str, Any]) -> "Content":
raw_content = gapic_content_types.Content()
json_format.ParseDict(content_dict, raw_content._pb)
raw_content = _dict_to_proto(gapic_content_types.Content, content_dict)
return cls._from_gapic(raw_content=raw_content)

def to_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -2381,8 +2406,7 @@ def _from_gapic(cls, raw_part: gapic_content_types.Part) -> "Part":

@classmethod
def from_dict(cls, part_dict: Dict[str, Any]) -> "Part":
raw_part = gapic_content_types.Part()
json_format.ParseDict(part_dict, raw_part._pb)
raw_part = _dict_to_proto(gapic_content_types.Part, part_dict)
return cls._from_gapic(raw_part=raw_part)

def __repr__(self) -> str:
Expand Down Expand Up @@ -2510,7 +2534,9 @@ def _from_gapic(

@classmethod
def from_dict(cls, safety_setting_dict: Dict[str, Any]) -> "SafetySetting":
raw_safety_setting = gapic_content_types.SafetySetting(safety_setting_dict)
raw_safety_setting = _dict_to_proto(
aiplatform_types.SafetySetting, safety_setting_dict
)
return cls._from_gapic(raw_safety_setting=raw_safety_setting)

def to_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -2760,6 +2786,15 @@ def _proto_to_dict(message) -> Dict[str, Any]:
)


def _dict_to_proto(message_type: Type[T], message_dict: Dict[str, Any]) -> T:
    """Builds a proto-plus message of `message_type` from a dictionary.

    Args:
        message_type: Message class to instantiate (called with no arguments).
        message_dict: Dictionary that is parsed into the new message.

    Returns:
        The newly created message, populated from `message_dict`.
    """
    # Constructing the message straight from the dict
    # (`message_type(message_dict)`) is not an option: it fails for classes
    # where GAPIC has renamed proto fields. Parse into the underlying
    # protobuf object instead.
    proto_message = message_type()
    raw_pb = proto_message._pb
    json_format.ParseDict(message_dict, raw_pb)
    return proto_message


def _dict_to_pretty_string(d: dict) -> str:
    """Serializes `d` to JSON text indented by two spaces."""
    pretty_json = json.dumps(d, indent=2)
    return pretty_json
Expand Down

0 comments on commit 3090812

Please sign in to comment.