Skip to content

Commit 95f8d7b

Browse files
authored
chore!: moving to pydantic2 (#1394)
1 parent af3f85b commit 95f8d7b

16 files changed

+45
-249
lines changed

pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ dependencies = [
55
"datasets",
66
"tiktoken",
77
"langchain",
8-
"langchain-core<0.3",
8+
"langchain-core",
99
"langchain-community",
1010
"langchain_openai",
11-
"openai>1",
12-
"pysbd>=0.3.4",
1311
"nest-asyncio",
1412
"appdirs",
13+
"pydantic>=2",
14+
"openai>1",
15+
"pysbd>=0.3.4",
1516
]
1617
dynamic = ["version", "readme"]
1718

src/ragas/_analytics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import requests
1111
from appdirs import user_data_dir
12-
from langchain_core.pydantic_v1 import BaseModel, Field
12+
from pydantic import BaseModel, Field
1313

1414
from ragas.utils import get_debug_mode
1515

src/ragas/cost.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from langchain_core.callbacks.base import BaseCallbackHandler
55
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
6-
from langchain_core.pydantic_v1 import BaseModel
6+
from pydantic import BaseModel
77

88
from ragas.utils import get_from_dict
99

@@ -39,7 +39,9 @@ def cost(
3939
+ self.output_tokens * cost_per_output_token
4040
)
4141

42-
def __eq__(self, other: "TokenUsage") -> bool:
42+
def __eq__(self, other: object) -> bool:
43+
if not isinstance(other, TokenUsage):
44+
return False
4345
return (
4446
self.input_tokens == other.input_tokens
4547
and self.output_tokens == other.output_tokens

src/ragas/llms/output_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from langchain_core.exceptions import OutputParserException
66
from langchain_core.output_parsers import PydanticOutputParser
7-
from langchain_core.pydantic_v1 import BaseModel
7+
from pydantic import BaseModel
88

99
from ragas.llms import BaseRagasLLM
1010
from ragas.llms.prompt import Prompt, PromptValue

src/ragas/metrics/_answer_correctness.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import dataclass, field
66

77
import numpy as np
8-
from langchain_core.pydantic_v1 import BaseModel
8+
from pydantic import BaseModel
99

1010
from ragas.dataset_schema import SingleTurnSample
1111
from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions

src/ragas/metrics/_answer_relevance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import dataclass, field
66

77
import numpy as np
8-
from langchain_core.pydantic_v1 import BaseModel
8+
from pydantic import BaseModel
99

1010
from ragas.dataset_schema import SingleTurnSample
1111
from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions

src/ragas/metrics/_context_entities_recall.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class ContextEntitiesResponse(BaseModel):
2424

2525

2626
_output_instructions = get_json_format_instructions(
27-
pydantic_object=ContextEntitiesResponse
27+
pydantic_object=ContextEntitiesResponse # type: ignore
2828
)
2929
_output_parser = RagasoutputParser(pydantic_object=ContextEntitiesResponse)
3030

src/ragas/metrics/_context_precision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ContextPrecisionVerifications(BaseModel):
3333

3434

3535
_verification_output_instructions = get_json_format_instructions(
36-
ContextPrecisionVerification
36+
ContextPrecisionVerification # type: ignore
3737
)
3838
_output_parser = RagasoutputParser(pydantic_object=ContextPrecisionVerification)
3939

src/ragas/metrics/_context_recall.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import dataclass, field
66

77
import numpy as np
8-
from langchain_core.pydantic_v1 import BaseModel
8+
from pydantic import BaseModel, RootModel
99

1010
from ragas.dataset_schema import SingleTurnSample
1111
from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions
@@ -29,11 +29,11 @@ class ContextRecallClassificationAnswer(BaseModel):
2929
reason: str
3030

3131

32-
class ContextRecallClassificationAnswers(BaseModel):
33-
__root__: t.List[ContextRecallClassificationAnswer]
32+
class ContextRecallClassificationAnswers(RootModel):
33+
root: t.List[ContextRecallClassificationAnswer]
3434

3535
def dicts(self) -> t.List[t.Dict]:
36-
return self.dict()["__root__"]
36+
return self.model_dump()
3737

3838

3939
_classification_output_instructions = get_json_format_instructions(

src/ragas/metrics/_faithfulness.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import dataclass, field
77

88
import numpy as np
9-
from langchain_core.pydantic_v1 import BaseModel, Field
9+
from pydantic import BaseModel, Field, RootModel
1010

1111
from ragas.dataset_schema import SingleTurnSample
1212
from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions
@@ -39,11 +39,8 @@ class Statements(BaseModel):
3939
simpler_statements: t.List[str] = Field(..., description="the simpler statements")
4040

4141

42-
class StatementsAnswers(BaseModel):
43-
__root__: t.List[Statements]
44-
45-
def dicts(self) -> t.List[t.Dict]:
46-
return self.dict()["__root__"]
42+
class StatementsAnswers(RootModel):
43+
root: t.List[Statements]
4744

4845

4946
_statements_output_instructions = get_json_format_instructions(StatementsAnswers)
@@ -79,7 +76,7 @@ def dicts(self) -> t.List[t.Dict]:
7976
],
8077
},
8178
]
82-
).dicts(),
79+
).model_dump(),
8380
}
8481
],
8582
input_keys=["question", "answer", "sentences"],
@@ -94,11 +91,11 @@ class StatementFaithfulnessAnswer(BaseModel):
9491
verdict: int = Field(..., description="the verdict(0/1) of the faithfulness.")
9592

9693

97-
class StatementFaithfulnessAnswers(BaseModel):
98-
__root__: t.List[StatementFaithfulnessAnswer]
94+
class StatementFaithfulnessAnswers(RootModel):
95+
root: t.List[StatementFaithfulnessAnswer]
9996

100-
def dicts(self) -> t.List[t.Dict]:
101-
return self.dict()["__root__"]
97+
def dicts(self):
98+
return self.model_dump()
10299

103100

104101
_faithfulness_output_instructions = get_json_format_instructions(
@@ -144,20 +141,20 @@ def dicts(self) -> t.List[t.Dict]:
144141
"verdict": 0,
145142
},
146143
]
147-
).dicts(),
144+
).model_dump(),
148145
},
149146
{
150147
"context": """Photosynthesis is a process used by plants, algae, and certain bacteria to convert light energy into chemical energy.""",
151148
"statements": ["Albert Einstein was a genius."],
152-
"answer": StatementFaithfulnessAnswers.parse_obj(
149+
"answer": StatementFaithfulnessAnswers.model_validate(
153150
[
154151
{
155152
"statement": "Albert Einstein was a genius.",
156153
"reason": "The context and statement are unrelated",
157154
"verdict": 0,
158155
}
159156
]
160-
).dicts(),
157+
).model_dump(),
161158
},
162159
],
163160
input_keys=["context", "statements"],
@@ -237,9 +234,9 @@ def _create_statements_prompt(self, row: t.Dict) -> PromptValue:
237234
def _compute_score(self, answers: StatementFaithfulnessAnswers):
238235
# check the verdicts and compute the score
239236
faithful_statements = sum(
240-
1 if answer.verdict else 0 for answer in answers.__root__
237+
1 if answer.verdict else 0 for answer in answers.model_dump()
241238
)
242-
num_statements = len(answers.__root__)
239+
num_statements = len(answers.model_dump())
243240
if num_statements:
244241
score = faithful_statements / num_statements
245242
else:
@@ -272,7 +269,7 @@ async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
272269
if statements is None:
273270
return np.nan
274271

275-
statements = [item["simpler_statements"] for item in statements.dicts()]
272+
statements = [item["simpler_statements"] for item in statements.model_dump()]
276273
statements = [item for sublist in statements for item in sublist]
277274

278275
assert isinstance(statements, t.List), "statements must be a list"
@@ -295,7 +292,7 @@ async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
295292
]
296293

297294
faithfulness_list = [
298-
faith.dicts() for faith in faithfulness_list if faith is not None
295+
faith.model_dump() for faith in faithfulness_list if faith is not None
299296
]
300297

301298
if faithfulness_list:
@@ -385,7 +382,7 @@ async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
385382
if statements is None:
386383
return np.nan
387384

388-
statements = [item["simpler_statements"] for item in statements.dicts()]
385+
statements = [item["simpler_statements"] for item in statements.model_dump()]
389386
statements = [item for sublist in statements for item in sublist]
390387

391388
assert isinstance(statements, t.List), "statements must be a list"

0 commit comments

Comments (0)