Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project]
name = "draive"
description = "Framework designed to simplify and accelerate the development of LLM-based applications."
version = "0.70.0"
version = "0.71.0"
readme = "README.md"
maintainers = [
{ name = "Kacper Kaliński", email = "kacper.kalinski@miquido.com" },
Expand All @@ -21,25 +21,25 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Application Frameworks",
]
license = { file = "LICENSE" }
dependencies = ["numpy~=2.2", "haiway~=0.21.4"]
dependencies = ["numpy~=2.2", "haiway~=0.22"]

[project.urls]
Homepage = "https://miquido.com"
Repository = "https://github.com/miquido/draive.git"

[project.optional-dependencies]
sentencepiece = ["sentencepiece~=0.2"]
cohere = ["cohere~=5.13"]
cohere = ["cohere~=5.15"]
cohere_bedrock = ["cohere~=5.13", "boto3~=1.37"]
openai = ["openai~=1.64", "tiktoken~=0.8"]
anthropic = ["anthropic~=0.47", "tokenizers~=0.21"]
anthropic_bedrock = ["anthropic[bedrock]~=0.47", "tokenizers~=0.21"]
mistral = ["draive[sentencepiece]", "mistralai~=1.5"]
gemini = ["draive[sentencepiece]", "google-genai~=1.10"]
ollama = ["ollama~=0.4"]
bedrock = ["boto3~=1.37"]
vllm = ["openai~=1.64"]
mcp = ["mcp~=1.5"]
openai = ["openai~=1.88", "tiktoken~=0.9"]
anthropic = ["anthropic~=0.54", "tokenizers~=0.21"]
anthropic_bedrock = ["anthropic[bedrock]~=0.54", "tokenizers~=0.21"]
mistral = ["draive[sentencepiece]", "mistralai~=1.8"]
gemini = ["draive[sentencepiece]", "google-genai~=1.20"]
ollama = ["ollama~=0.5"]
bedrock = ["boto3~=1.38"]
vllm = ["openai~=1.88"]
mcp = ["mcp~=1.9"]
opentelemetry = [
"haiway[opentelemetry]",
"opentelemetry-api",
Expand All @@ -52,15 +52,15 @@ dev = [
"pytest~=8.3",
"pytest-asyncio~=0.26",
"pytest-cov~=6.1",
"ruff~=0.11",
"ruff~=0.12",
]

[tool.ruff]
target-version = "py312"
line-length = 100
extend-exclude = [".venv", ".git", ".cache"]
lint.select = ["E", "F", "A", "I", "B", "PL", "W", "C", "RUF", "UP", "NPY201"]
lint.ignore = ["A005"]
lint.ignore = ["A005", "PLC0415"]
lint.pydocstyle.convention = "numpy"

[tool.ruff.lint.per-file-ignores]
Expand Down
16 changes: 16 additions & 0 deletions src/draive/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,20 @@
DefaultValue,
Disposable,
Disposables,
File,
FileAccess,
LoggerObservability,
Missing,
MissingContext,
MissingState,
Observability,
ObservabilityAttribute,
ObservabilityContext,
ObservabilityLevel,
ScopeContext,
ScopeIdentifier,
State,
StateContext,
always,
as_dict,
as_list,
Expand All @@ -28,6 +33,7 @@
asynchronous,
cache,
ctx,
getenv_base64,
getenv_bool,
getenv_float,
getenv_int,
Expand All @@ -36,12 +42,14 @@
load_env,
noop,
not_missing,
process_concurrently,
retry,
setup_logging,
throttle,
timeout,
traced,
when_missing,
without_missing,
)

from draive.agents import (
Expand Down Expand Up @@ -268,6 +276,8 @@
"Disposables",
"Embedded",
"Field",
"File",
"FileAccess",
"GuardrailsAnonymization",
"GuardrailsAnonymizedContent",
"GuardrailsInputModerationException",
Expand Down Expand Up @@ -306,6 +316,7 @@
"LMMToolResponseHandling",
"LMMToolResponses",
"LMMTools",
"LoggerObservability",
"MediaContent",
"MediaData",
"MediaKind",
Expand Down Expand Up @@ -357,12 +368,14 @@
"ResourceTemplate",
"ResourceTemplateDeclaration",
"Resources",
"ScopeContext",
"ScopeIdentifier",
"SelectionException",
"Stage",
"StageException",
"StageState",
"State",
"StateContext",
"TextContent",
"TextEmbedding",
"TextGeneration",
Expand Down Expand Up @@ -390,6 +403,7 @@
"cache",
"choice_completion",
"ctx",
"getenv_base64",
"getenv_bool",
"getenv_float",
"getenv_int",
Expand All @@ -401,6 +415,7 @@
"noop",
"not_missing",
"prepare_instruction",
"process_concurrently",
"prompt",
"refine_instruction",
"resource",
Expand All @@ -416,5 +431,6 @@
"vector_similarity_score",
"vector_similarity_search",
"when_missing",
"without_missing",
"workflow",
)
3 changes: 3 additions & 0 deletions src/draive/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def __eq__(self, other: Self) -> bool:

return self.score == other.score

def __hash__(self) -> int:
    """Return a hash consistent with ``__eq__``.

    ``__eq__`` for this class compares instances by ``score`` alone, so the
    hash must depend only on ``score`` as well — otherwise two objects that
    compare equal could land in different hash buckets, breaking ``set`` and
    ``dict`` usage.  (The previous implementation mixed in ``evaluator`` and
    ``threshold``, violating that invariant.)
    """
    return hash(self.score)


class EvaluationResult(DataModel):
@classmethod
Expand Down
3 changes: 3 additions & 0 deletions src/draive/evaluation/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,6 @@ def __ge__(self, other: Any) -> bool:

case _:
return NotImplemented

def __hash__(self) -> int:
    """Return a hash derived from this score's value and comment."""
    # Fold both identifying fields into a single tuple so the hash agrees
    # for instances carrying the same (value, comment) pair.
    key = (self.value, self.comment)
    return hash(key)
6 changes: 5 additions & 1 deletion src/draive/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,13 @@ def __init__(

@asynccontextmanager
async def lifspan(server: Server) -> AsyncGenerator[Iterable[State]]:
    # NOTE(review): "lifspan" looks like a typo for "lifespan" — it is
    # presumably passed to the Server constructor below; verify all call
    # sites before renaming.
    # Eagerly prepare the disposable's state, hand it to the server for
    # its whole lifetime, and guarantee disposal runs even if the server
    # body raises.
    state: Iterable[State] = await disposable.prepare()
    try:
        yield state

    finally:
        await disposable.dispose()

self._server = Server[Iterable[State]](
name=name,
version=version,
Expand Down
18 changes: 18 additions & 0 deletions src/draive/parameters/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,24 @@ def __eq__(self, other: Any) -> bool:
for key in self.__PARAMETERS__.keys()
)

def __hash__(self) -> int:
    """Return a hash computed from the model's hashable parameter values.

    The hash combines the concrete class with the hashes of all parameter
    values that are present and hashable.  Values that are ``MISSING`` or
    unhashable are skipped: this keeps hashing total (never raises), at the
    cost of extra collisions for models differing only in skipped fields —
    ``__eq__`` remains the authority for actual equality.
    """
    # Base class / uninitialized parameter spec: __PARAMETERS__ may be the
    # MISSING sentinel, on which .keys() would raise AttributeError.  Hash
    # on the class alone so such instances stay usable as dict/set members.
    if self.__PARAMETERS__ is MISSING:
        return hash((self.__class__, ()))

    hash_values: list[int] = []
    for key in self.__PARAMETERS__.keys():
        value: Any = getattr(self, key, MISSING)

        # Skip MISSING values to ensure consistent hashing regardless of
        # whether an absent field was explicitly defaulted.
        if value is MISSING:
            continue

        # Convert to hashable representation where possible.
        try:
            hash_values.append(hash(value))

        except TypeError:
            continue  # skip unhashable (lists, dicts, sets, ...)

    return hash((self.__class__, tuple(hash_values)))
Comment on lines +616 to +632
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

__hash__ breaks for the base DataModel and can silently collide

  1. self.__PARAMETERS__ may be the sentinel MISSING (e.g. on the un-parameterised base class).
    Calling .keys() on it raises AttributeError, so hash(DataModel()) will blow up.

  2. Skipping every unhashable value means two non-equal models can share the same hash, increasing
    the collision rate for sets / dict keys that use these instances.

Proposed minimal fix – guard for the sentinel and include a fallback hashing strategy:

     def __hash__(self) -> int:
-        hash_values: list[int] = []
-        for key in self.__PARAMETERS__.keys():
+        # Base class or invalid state – nothing to hash, but keep instances usable
+        if self.__PARAMETERS__ is MISSING:
+            return hash((self.__class__, ()))
+
+        hash_values: list[int] = []
+        for key in self.__PARAMETERS__:
             value: Any = getattr(self, key, MISSING)
@@
-        return hash((self.__class__, tuple(hash_values)))
+        return hash((self.__class__, tuple(hash_values)))

(Optional) To further reduce collisions, consider recursively converting common container
types (list, set, dict) into hashable equivalents instead of silently skipping them.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def __hash__(self) -> int:
hash_values: list[int] = []
for key in self.__PARAMETERS__.keys():
value: Any = getattr(self, key, MISSING)
# Skip MISSING values to ensure consistent hashing
if value is MISSING:
continue
# Convert to hashable representation
try:
hash_values.append(hash(value))
except TypeError:
continue # skip unhashable
return hash((self.__class__, tuple(hash_values)))
def __hash__(self) -> int:
# Base class or invalid state – nothing to hash, but keep instances usable
if self.__PARAMETERS__ is MISSING:
return hash((self.__class__, ()))
hash_values: list[int] = []
for key in self.__PARAMETERS__:
value: Any = getattr(self, key, MISSING)
# Skip MISSING values to ensure consistent hashing
if value is MISSING:
continue
# Convert to hashable representation
try:
hash_values.append(hash(value))
except TypeError:
continue # skip unhashable
return hash((self.__class__, tuple(hash_values)))
🤖 Prompt for AI Agents
In src/draive/parameters/model.py around lines 616 to 632, the __hash__ method
fails when self.__PARAMETERS__ is the sentinel MISSING, causing an
AttributeError on .keys(). To fix this, first check if self.__PARAMETERS__ is
MISSING and handle that case separately to avoid the error. Additionally,
instead of skipping unhashable values which can cause hash collisions, implement
a fallback hashing strategy that converts common container types like list, set,
and dict into hashable equivalents (e.g., tuples or frozensets) before hashing.
This will reduce silent collisions and improve hash uniqueness.


def __contains__(
self,
element: Any,
Expand Down
8 changes: 3 additions & 5 deletions src/draive/stages/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,9 +1144,8 @@ async def stage(
*,
state: StageState,
) -> StageState:
async with ctx_disposables as features:
with ctx.updated(*features):
return await execution(state=state)
async with ctx_disposables:
return await execution(state=state)

case (ctx_state, None):

Expand All @@ -1163,11 +1162,10 @@ async def stage(
*,
state: StageState,
) -> StageState:
async with ctx_disposables as features:
async with ctx_disposables:
with ctx.updated(
ctx_state,
*states,
*features,
# preserve current Processing state by replacing it
ctx.state(Processing),
):
Expand Down
Loading