Skip to content

Commit

Permalink
fix: core: Include in json output also fields set outside the constru…
Browse files Browse the repository at this point in the history
…ctor (#21342)
  • Loading branch information
nfcampos authored May 6, 2024
1 parent ac14f17 commit 6f17158
Show file tree
Hide file tree
Showing 9 changed files with 477 additions and 349 deletions.
28 changes: 20 additions & 8 deletions libs/core/langchain_core/load/serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from typing_extensions import NotRequired

from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
from langchain_core.pydantic_v1 import BaseModel


class BaseSerialized(TypedDict):
Expand Down Expand Up @@ -114,12 +114,6 @@ def __repr_args__(self) -> Any:
if (k not in self.__fields__ or try_neq_default(v, k, self))
]

_lc_kwargs = PrivateAttr(default_factory=dict)

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._lc_kwargs = kwargs

def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
if not self.is_lc_serializable():
return self.to_json_not_implemented()
Expand All @@ -128,8 +122,9 @@ def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
# Get latest values for kwargs if there is an attribute with same name
lc_kwargs = {
k: getattr(self, k, v)
for k, v in self._lc_kwargs.items()
for k, v in self
if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
and _is_field_useful(self, k, v)
}

# Merge the lc_secrets and lc_attributes from every class in the MRO
Expand Down Expand Up @@ -186,6 +181,23 @@ def to_json_not_implemented(self) -> SerializedNotImplemented:
return to_json_not_implemented(self)


def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
"""Check if a field is useful as a constructor argument.
Args:
inst: The instance.
key: The key.
value: The value.
Returns:
Whether the field is useful.
"""
field = inst.__fields__.get(key)
if not field:
return False
return field.required is True or value or field.get_default() != value


def _replace_secrets(
root: Dict[Any, Any], secrets_map: Dict[str, str]
) -> Dict[Any, Any]:
Expand Down
10 changes: 10 additions & 0 deletions libs/core/tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Configuration for unit tests."""
from importlib import util
from typing import Dict, Sequence
from uuid import UUID

import pytest
from pytest import Config, Function, Parser
from pytest_mock import MockerFixture


def pytest_addoption(parser: Parser) -> None:
Expand Down Expand Up @@ -85,3 +87,11 @@ def test_something():
item.add_marker(
pytest.mark.skip(reason="Skipping not an extended test.")
)


@pytest.fixture()
def deterministic_uuids(mocker: MockerFixture) -> MockerFixture:
side_effect = (
UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
)
return mocker.patch("uuid.uuid4", side_effect=side_effect)
2 changes: 2 additions & 0 deletions libs/core/tests/unit_tests/messages/test_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_serdes_message() -> None:
"type": "constructor",
"id": ["langchain", "schema", "messages", "AIMessage"],
"kwargs": {
"type": "ai",
"content": [{"text": "blah", "type": "text"}],
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
"invalid_tool_calls": [
Expand All @@ -46,6 +47,7 @@ def test_serdes_message_chunk() -> None:
"type": "constructor",
"id": ["langchain", "schema", "messages", "AIMessageChunk"],
"kwargs": {
"type": "AIMessageChunk",
"content": [{"text": "blah", "type": "text"}],
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
"invalid_tool_calls": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@
]
}
},
"middle": [],
"last": {
"lc": 1,
"type": "constructor",
Expand Down Expand Up @@ -109,8 +108,7 @@
"buz"
],
"template": "what did baz say to {buz}",
"template_format": "f-string",
"partial_variables": {}
"template_format": "f-string"
},
"name": "PromptTemplate",
"graph": {
Expand Down Expand Up @@ -151,7 +149,6 @@
]
}
},
"middle": [],
"last": {
"lc": 1,
"type": "not_implemented",
Expand Down Expand Up @@ -200,8 +197,7 @@
}
]
}
},
"name": null
}
},
"name": "RunnableSequence",
"graph": {
Expand Down Expand Up @@ -284,8 +280,7 @@
"buz"
],
"template": "what did baz say to {buz}",
"template_format": "f-string",
"partial_variables": {}
"template_format": "f-string"
},
"name": "PromptTemplate",
"graph": {
Expand Down Expand Up @@ -326,7 +321,6 @@
]
}
},
"middle": [],
"last": {
"lc": 1,
"type": "not_implemented",
Expand Down Expand Up @@ -375,8 +369,7 @@
}
]
}
},
"name": null
}
},
"name": "RunnableSequence",
"graph": {
Expand Down Expand Up @@ -445,8 +438,7 @@
],
"repr": "<class 'Exception'>"
}
],
"exception_key": null
]
},
"name": "RunnableWithFallbacks",
"graph": {
Expand Down Expand Up @@ -486,8 +478,7 @@
}
]
}
},
"name": null
}
},
"name": "RunnableSequence",
"graph": {
Expand Down Expand Up @@ -579,11 +570,7 @@
"runnable",
"RunnablePassthrough"
],
"kwargs": {
"func": null,
"afunc": null,
"input_type": null
},
"kwargs": {},
"name": "RunnablePassthrough",
"graph": {
"nodes": [
Expand Down Expand Up @@ -664,7 +651,6 @@
]
}
},
"middle": [],
"last": {
"lc": 1,
"type": "constructor",
Expand Down Expand Up @@ -750,8 +736,7 @@
}
]
}
},
"name": null
}
},
"name": "RunnableSequence",
"graph": {
Expand Down Expand Up @@ -933,8 +918,7 @@
],
"repr": "<class 'Exception'>"
}
],
"exception_key": null
]
},
"name": "RunnableWithFallbacks",
"graph": {
Expand Down Expand Up @@ -1148,8 +1132,7 @@
],
"repr": "<class 'Exception'>"
}
],
"exception_key": null
]
},
"name": "RunnableWithFallbacks",
"graph": {
Expand Down
Loading

0 comments on commit 6f17158

Please sign in to comment.