Skip to content

Commit b91b63c

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Initial implementation of the SDK for Memory Revisions
PiperOrigin-RevId: 818856682
1 parent ac6e0b4 commit b91b63c

File tree

4 files changed

+984
-18
lines changed

4 files changed

+984
-18
lines changed

tests/unit/vertexai/genai/replays/test_generate_agent_engine_memories.py

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,19 @@
1919
from google.genai import types as genai_types
2020

2121

22-
def test_generate_memories(client):
22+
def test_generate_and_rollback_memories(client):
23+
client._api_client._http_options.base_url = (
24+
"https://us-central1-autopush-aiplatform.sandbox.googleapis.com/"
25+
)
2326
agent_engine = client.agent_engines.create()
2427
assert not list(
25-
client.agent_engines.list_memories(
28+
client.agent_engines.memories.list(
2629
name=agent_engine.api_resource.name,
2730
)
2831
)
29-
client.agent_engines.generate_memories(
32+
# Generate memories using source content. This result is non-deterministic,
33+
# because an LLM is used to generate the memories.
34+
client.agent_engines.memories.generate(
3035
name=agent_engine.api_resource.name,
3136
scope={"user_id": "test-user-id"},
3237
direct_contents_source=types.GenerateMemoriesRequestDirectContentsSource(
@@ -43,23 +48,72 @@ def test_generate_memories(client):
4348
)
4449
]
4550
),
51+
config=types.GenerateAgentEngineMemoriesConfig(
52+
revision_labels={"key": "value"}
53+
),
4654
)
47-
assert (
48-
len(
49-
list(
50-
client.agent_engines.list_memories(
51-
name=agent_engine.api_resource.name,
52-
)
53-
)
55+
memories = list(
56+
client.agent_engines.memories.list(
57+
name=agent_engine.api_resource.name,
5458
)
55-
>= 1
5659
)
60+
assert len(memories) >= 1
61+
62+
# Every action that modifies a memory creates a new revision.
63+
memory_revisions = list(
64+
client.agent_engines.memories.revisions.list(
65+
name=memories[0].name,
66+
)
67+
)
68+
assert len(memory_revisions) >= 1
69+
# The revision's labels depend on the generation request's revision labels.
70+
assert memory_revisions[0].labels == {"key": "value"}
71+
revision_name = memory_revisions[0].name
72+
73+
# Update the memory.
74+
client.agent_engines.memories._update(
75+
name=memories[0].name,
76+
fact="This is temporary",
77+
scope={"user_id": "test-user-id"},
78+
)
79+
memory = client.agent_engines.memories.get(name=memories[0].name)
80+
assert memory.fact == "This is temporary"
81+
82+
# Rollback to the revision with the original fact that was created by the
83+
# generation request.
84+
client.agent_engines.memories.rollback(
85+
name=memories[0].name,
86+
target_revision_id=revision_name.split("/")[-1],
87+
)
88+
memory = client.agent_engines.memories.get(name=memories[0].name)
89+
assert memory.fact == memory_revisions[0].fact
90+
91+
# Update the memory again using generation. We use the original source
92+
# content to ensure that the original memory is updated. The response should
93+
# refer to the previous revision.
94+
response = client.agent_engines.memories.generate(
95+
name=agent_engine.api_resource.name,
96+
scope={"user_id": "test-user-id"},
97+
direct_contents_source=types.GenerateMemoriesRequestDirectContentsSource(
98+
events=[
99+
types.GenerateMemoriesRequestDirectContentsSourceEvent(
100+
content=genai_types.Content(
101+
role="model",
102+
parts=[genai_types.Part(text=memory_revisions[0].fact)],
103+
)
104+
)
105+
]
106+
),
107+
)
108+
# The memory was updated, so the previous revision is set.
109+
assert response.response.generated_memories[0].previous_revision is not None
110+
57111
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
58112

59113

60114
def test_generate_memories_direct_memories_source(client):
61115
agent_engine = client.agent_engines.create()
62-
client.agent_engines.generate_memories(
116+
client.agent_engines.memories.generate(
63117
name=agent_engine.api_resource.name,
64118
scope={"user_id": "test-user-id"},
65119
direct_memories_source=types.GenerateMemoriesRequestDirectMemoriesSource(
@@ -77,7 +131,7 @@ def test_generate_memories_direct_memories_source(client):
77131
assert (
78132
len(
79133
list(
80-
client.agent_engines.list_memories(
134+
client.agent_engines.memories.list(
81135
name=agent_engine.api_resource.name,
82136
)
83137
)

0 commit comments

Comments
 (0)