Skip to content

fix: ensure version values are strings before Tag comparison (#40) #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions langgraph/checkpoint/redis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,19 +288,22 @@ def _dump_blobs(
storage_safe_thread_id = to_storage_safe_id(thread_id)
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)

# Ensure all versions are converted to strings to avoid TypeError with Tag filters
str_versions = {k: str(v) for k, v in versions.items()}

return [
(
BaseRedisSaver._make_redis_checkpoint_blob_key(
storage_safe_thread_id,
storage_safe_checkpoint_ns,
k,
cast(str, ver),
str_versions[k], # Use the string version
),
{
"thread_id": storage_safe_thread_id,
"checkpoint_ns": storage_safe_checkpoint_ns,
"channel": k,
"version": cast(str, ver),
"version": str_versions[k], # Use the string version
"type": (
self._get_type_and_blob(values[k])[0]
if k in values
Expand All @@ -311,7 +314,7 @@ def _dump_blobs(
),
},
)
for k, ver in versions.items()
for k in str_versions.keys()
]

def _dump_writes(
Expand Down
103 changes: 103 additions & 0 deletions tests/test_numeric_version_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""
Test for the fix to issue #40 - Fixing numeric version handling with Tag type.
"""

from contextlib import contextmanager

import pytest
from langgraph.checkpoint.base import empty_checkpoint
from redis import Redis

from langgraph.checkpoint.redis import RedisSaver


@pytest.fixture(autouse=True)
async def clear_test_redis(redis_url: str) -> None:
"""Clear Redis before each test."""
client = Redis.from_url(redis_url)
try:
client.flushall()
finally:
client.close()


@contextmanager
def patched_redis_saver(redis_url):
"""
Create a RedisSaver with a patched _dump_blobs method to fix the issue.
This demonstrates the fix approach.
"""
original_dump_blobs = RedisSaver._dump_blobs

def patched_dump_blobs(self, thread_id, checkpoint_ns, values, versions):
"""
Patched version of _dump_blobs that ensures version is a string.
"""
# Convert version to string in versions dictionary
string_versions = {k: str(v) for k, v in versions.items()}

# Call the original method with string versions
return original_dump_blobs(
self, thread_id, checkpoint_ns, values, string_versions
)

# Apply the patch
RedisSaver._dump_blobs = patched_dump_blobs

try:
# Create the saver with patched method
saver = RedisSaver(redis_url)
saver.setup()
yield saver
finally:
# Restore the original method
RedisSaver._dump_blobs = original_dump_blobs
# Clean up
if saver._owns_its_client:
saver._redis.close()


def test_numeric_version_fix(redis_url: str) -> None:
"""
Test that demonstrates the fix for issue #40.

This shows how to handle numeric versions correctly by ensuring
they are converted to strings before being used with Tag.
"""
# Use our patched version that converts numeric versions to strings
with patched_redis_saver(redis_url) as saver:
# Set up a basic config
config = {
"configurable": {
"thread_id": "thread-numeric-version-fix",
"checkpoint_ns": "",
}
}

# Create a basic checkpoint
checkpoint = empty_checkpoint()

# Store the checkpoint with our patched method
saved_config = saver.put(
config, checkpoint, {}, {"test_channel": 1}
) # Numeric version

# Get the checkpoint ID from the saved config
thread_id = saved_config["configurable"]["thread_id"]
checkpoint_ns = saved_config["configurable"].get("checkpoint_ns", "")

# Now query the data - this should work with the fix
query = f"@channel:{{test_channel}}"

# This should not raise an error now with our patch
results = saver.checkpoint_blobs_index.search(query)

# Verify we can find the data
assert len(results.docs) > 0

# Load one document and verify the version is a string
doc = results.docs[0]
data = saver._redis.json().get(doc.id)

# The key test: version should be a string even though we passed a numeric value
assert isinstance(data["version"], str)
73 changes: 73 additions & 0 deletions tests/test_numeric_version_issue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
Test for issue #40 - Error when comparing numeric version with Tag type.
"""

import pytest
from langgraph.checkpoint.base import empty_checkpoint
from redis import Redis
from redisvl.query.filter import Tag

from langgraph.checkpoint.redis import RedisSaver


@pytest.fixture(autouse=True)
async def clear_test_redis(redis_url: str) -> None:
"""Clear Redis before each test."""
client = Redis.from_url(redis_url)
try:
client.flushall()
finally:
client.close()


def test_numeric_version_issue(redis_url: str) -> None:
"""
Test reproduction for issue #40.

This test explicitly creates a scenario where a numeric version field
is compared with a Tag type, which should cause a TypeError.
"""
# Create a Redis saver with default configuration
saver = RedisSaver(redis_url)
saver.setup()

try:
# Here we'll directly test the specific problem from issue #40
# In the real app, the version field is stored as a number in Redis
# Then when the code in _dump_blobs tries to pass that numeric version
# to the Tag filter, it causes a TypeError

# First create a fixed test with direct Tag usage to demonstrate the issue
tag_filter = Tag("version")

with pytest.raises(TypeError) as excinfo:
# This will trigger the error because we're comparing Tag with integer
result = tag_filter == 1 # Integer instead of string

# Verify the specific error message related to Tag comparison
assert "Right side argument passed to operator" in str(excinfo.value)
assert "Tag must be of type" in str(excinfo.value)

# Another approach would be a direct test of our _dump_blobs method
# by creating a fake numeric version and then trying to create a Tag query
# based on it
channel_name = "test_channel"
numeric_version = 1 # This is the root issue - numeric version not string

# This mimics the code in _dump_blobs that would fail
versions = {channel_name: numeric_version}

# We can't directly patch the method, but we can verify the same type issue
# Here we simulate what happens when a numeric version is passed to Tag filter
tag_filter = Tag("version")
with pytest.raises(TypeError) as excinfo2:
# This fails because we're comparing a Tag with a numeric value directly
result = tag_filter == versions[channel_name] # Numeric version

# Check the error message
assert "must be of type" in str(excinfo2.value)

finally:
# Clean up
if saver._owns_its_client:
saver._redis.close()