Skip to content

Commit f27a9cf

Browse files
GWealecopybara-github
authored andcommitted
fix: Expand add_memory to accept MemoryEntry
The `add_memory` methods in `Context` and `BaseMemoryService` now accept `MemoryEntry` objects in addition to strings. The Vertex AI Memory Bank service implementation is updated to handle these new types Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 872108561
1 parent 2d8b6a2 commit f27a9cf

File tree

6 files changed

+479
-60
lines changed

6 files changed

+479
-60
lines changed

src/google/adk/agents/context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from ..events.event import Event
3333
from ..events.event_actions import EventActions
3434
from ..memory.base_memory_service import SearchMemoryResponse
35+
from ..memory.memory_entry import MemoryEntry
3536
from ..sessions.state import State
3637
from ..tools.tool_confirmation import ToolConfirmation
3738
from .invocation_context import InvocationContext
@@ -366,7 +367,7 @@ async def add_events_to_memory(
366367
async def add_memory(
367368
self,
368369
*,
369-
memories: Sequence[str],
370+
memories: Sequence[MemoryEntry],
370371
custom_metadata: Mapping[str, object] | None = None,
371372
) -> None:
372373
"""Adds explicit memory items directly to the memory service.

src/google/adk/memory/base_memory_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ async def add_memory(
9999
*,
100100
app_name: str,
101101
user_id: str,
102-
memories: Sequence[str],
102+
memories: Sequence[MemoryEntry],
103103
custom_metadata: Mapping[str, object] | None = None,
104104
) -> None:
105105
"""Adds explicit memory items directly to the memory service.

src/google/adk/memory/vertex_ai_memory_bank_service.py

Lines changed: 174 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
'expire_time',
5858
'http_options',
5959
'metadata',
60+
'revision_labels',
6061
'revision_expire_time',
6162
'revision_ttl',
6263
'topics',
@@ -215,7 +216,7 @@ async def add_memory(
215216
*,
216217
app_name: str,
217218
user_id: str,
218-
memories: Sequence[str],
219+
memories: Sequence[MemoryEntry],
219220
custom_metadata: Mapping[str, object] | None = None,
220221
) -> None:
221222
"""Adds explicit memory items via Vertex memories.create."""
@@ -267,20 +268,29 @@ async def _add_memories_via_create(
267268
*,
268269
app_name: str,
269270
user_id: str,
270-
memories: Sequence[str],
271+
memories: Sequence[MemoryEntry],
271272
custom_metadata: Mapping[str, object] | None = None,
272273
) -> None:
273274
"""Adds direct memory items without server-side extraction."""
274275
if not self._agent_engine_id:
275276
raise ValueError('Agent Engine ID is required for Memory Bank.')
276277

277-
memory_texts = _validate_memory_texts(memories)
278+
normalized_memories = _normalize_memories_for_create(memories)
278279
api_client = self._get_api_client()
279-
config = _build_create_memory_config(custom_metadata)
280-
for memory_text in memory_texts:
280+
for index, memory in enumerate(normalized_memories):
281+
memory_fact = _memory_entry_to_fact(memory, index=index)
282+
memory_metadata = _merge_custom_metadata_for_memory(
283+
custom_metadata=custom_metadata,
284+
memory=memory,
285+
)
286+
memory_revision_labels = _revision_labels_for_memory(memory)
287+
config = _build_create_memory_config(
288+
memory_metadata,
289+
memory_revision_labels=memory_revision_labels,
290+
)
281291
operation = await api_client.agent_engines.memories.create(
282292
name='reasoningEngines/' + self._agent_engine_id,
283-
fact=memory_text,
293+
fact=memory_fact,
284294
scope={
285295
'app_name': app_name,
286296
'user_id': user_id,
@@ -431,18 +441,21 @@ def _build_generate_memories_config(
431441

432442
def _build_create_memory_config(
433443
custom_metadata: Mapping[str, object] | None,
444+
*,
445+
memory_revision_labels: Mapping[str, str] | None = None,
434446
) -> dict[str, object]:
435447
"""Builds a valid memories.create config from caller metadata."""
436448
config: dict[str, object] = {'wait_for_completion': False}
437449
supports_metadata = _supports_create_memory_metadata()
438450
config_keys = _get_create_memory_config_keys()
439-
if not custom_metadata:
440-
return config
451+
supports_revision_labels = 'revision_labels' in config_keys
441452

442-
logger.debug('Memory creation metadata: %s', custom_metadata)
453+
if custom_metadata:
454+
logger.debug('Memory creation metadata: %s', custom_metadata)
443455

444456
metadata_by_key: dict[str, object] = {}
445-
for key, value in custom_metadata.items():
457+
custom_revision_labels: dict[str, str] = {}
458+
for key, value in (custom_metadata or {}).items():
446459
if key == 'metadata':
447460
if value is None:
448461
continue
@@ -460,63 +473,172 @@ def _build_create_memory_config(
460473
' mapping.'
461474
)
462475
continue
476+
if key == 'revision_labels':
477+
if value is None:
478+
continue
479+
extracted_labels = _extract_revision_labels(
480+
value,
481+
source='custom_metadata["revision_labels"]',
482+
)
483+
if extracted_labels:
484+
custom_revision_labels.update(extracted_labels)
485+
continue
463486
if key in config_keys:
464487
if value is None:
465488
continue
466489
config[key] = value
467490
else:
468491
metadata_by_key[key] = value
469492

470-
if not metadata_by_key:
471-
return config
472-
473-
if not supports_metadata:
474-
logger.warning(
475-
'Ignoring custom metadata keys %s because installed Vertex SDK does '
476-
'not support create config.metadata.',
477-
sorted(metadata_by_key.keys()),
478-
)
479-
return config
480-
481-
existing_metadata = config.get('metadata')
482-
if existing_metadata is None:
483-
config['metadata'] = _build_vertex_metadata(metadata_by_key)
484-
return config
485-
486-
if isinstance(existing_metadata, Mapping):
487-
merged_metadata = dict(existing_metadata)
488-
merged_metadata.update(_build_vertex_metadata(metadata_by_key))
489-
config['metadata'] = merged_metadata
490-
return config
493+
if metadata_by_key:
494+
if not supports_metadata:
495+
logger.warning(
496+
'Ignoring custom metadata keys %s because installed Vertex SDK does '
497+
'not support create config.metadata.',
498+
sorted(metadata_by_key.keys()),
499+
)
500+
else:
501+
existing_metadata = config.get('metadata')
502+
if existing_metadata is None:
503+
config['metadata'] = _build_vertex_metadata(metadata_by_key)
504+
elif isinstance(existing_metadata, Mapping):
505+
merged_metadata = dict(existing_metadata)
506+
merged_metadata.update(_build_vertex_metadata(metadata_by_key))
507+
config['metadata'] = merged_metadata
508+
else:
509+
logger.warning(
510+
'Ignoring custom metadata keys %s because config.metadata is not a'
511+
' mapping.',
512+
sorted(metadata_by_key.keys()),
513+
)
491514

492-
logger.warning(
493-
'Ignoring custom metadata keys %s because config.metadata is not a'
494-
' mapping.',
495-
sorted(metadata_by_key.keys()),
496-
)
515+
revision_labels = dict(custom_revision_labels)
516+
if memory_revision_labels:
517+
revision_labels.update(memory_revision_labels)
518+
if revision_labels:
519+
if supports_revision_labels:
520+
config['revision_labels'] = revision_labels
521+
else:
522+
logger.warning(
523+
'Ignoring revision labels %s because installed Vertex SDK does not '
524+
'support create config.revision_labels.',
525+
sorted(revision_labels.keys()),
526+
)
497527
return config
498528

499529

500-
def _validate_memory_texts(
501-
memories: Sequence[str],
502-
) -> list[str]:
503-
"""Validates direct textual memory items passed to add_memory."""
530+
def _normalize_memories_for_create(
531+
memories: Sequence[MemoryEntry],
532+
) -> list[MemoryEntry]:
533+
"""Validates add_memory inputs."""
504534
if isinstance(memories, str):
505-
raise TypeError('memories must be a sequence of strings.')
535+
raise TypeError('memories must be a sequence of memory items.')
506536
if not isinstance(memories, Sequence):
507-
raise TypeError('memories must be a sequence of strings.')
508-
memory_texts: list[str] = []
537+
raise TypeError('memories must be a sequence of memory items.')
538+
539+
validated_memories: list[MemoryEntry] = []
509540
for index, raw_memory in enumerate(memories):
510-
if not isinstance(raw_memory, str):
511-
raise TypeError(f'memories[{index}] must be a string.')
512-
memory_text = raw_memory.strip()
513-
if not memory_text:
514-
raise ValueError(f'memories[{index}] must not be empty.')
515-
memory_texts.append(memory_text)
516-
517-
if not memory_texts:
541+
if not isinstance(raw_memory, MemoryEntry):
542+
raise TypeError(f'memories[{index}] must be a MemoryEntry.')
543+
validated_memories.append(raw_memory)
544+
if not validated_memories:
518545
raise ValueError('memories must contain at least one entry.')
519-
return memory_texts
546+
return validated_memories
547+
548+
549+
def _memory_entry_to_fact(
550+
memory: MemoryEntry,
551+
*,
552+
index: int,
553+
) -> str:
554+
"""Builds a memories.create fact payload from MemoryEntry text content."""
555+
if _should_filter_out_event(memory.content):
556+
raise ValueError(f'memories[{index}] must include text.')
557+
558+
text_parts: list[str] = []
559+
for part in memory.content.parts:
560+
if part.inline_data or part.file_data:
561+
raise ValueError(
562+
f'memories[{index}] must include text only; inline_data and '
563+
'file_data are not supported.'
564+
)
565+
566+
if not part.text:
567+
continue
568+
stripped_text = part.text.strip()
569+
if stripped_text:
570+
text_parts.append(stripped_text)
571+
572+
if not text_parts:
573+
raise ValueError(f'memories[{index}] must include non-whitespace text.')
574+
return '\n'.join(text_parts)
575+
576+
577+
def _merge_custom_metadata_for_memory(
578+
*,
579+
custom_metadata: Mapping[str, object] | None,
580+
memory: MemoryEntry,
581+
) -> Mapping[str, object] | None:
582+
"""Merges write-level metadata with MemoryEntry metadata."""
583+
merged_metadata: dict[str, object] = {}
584+
585+
if custom_metadata:
586+
merged_metadata.update(dict(custom_metadata))
587+
if memory.custom_metadata:
588+
merged_metadata.update(memory.custom_metadata)
589+
590+
if not merged_metadata:
591+
return None
592+
return merged_metadata
593+
594+
595+
def _revision_labels_for_memory(
596+
memory: MemoryEntry,
597+
) -> Mapping[str, str] | None:
598+
"""Builds revision labels from MemoryEntry revision metadata."""
599+
revision_labels: dict[str, str] = {}
600+
if memory.author is not None:
601+
revision_labels['author'] = memory.author
602+
if memory.timestamp is not None:
603+
revision_labels['timestamp'] = memory.timestamp
604+
605+
if not revision_labels:
606+
return None
607+
return revision_labels
608+
609+
610+
def _extract_revision_labels(
611+
value: object,
612+
*,
613+
source: str,
614+
) -> Mapping[str, str] | None:
615+
"""Extracts revision labels from config metadata."""
616+
if not isinstance(value, Mapping):
617+
logger.warning('Ignoring %s because it is not a mapping.', source)
618+
return None
619+
620+
revision_labels: dict[str, str] = {}
621+
for key, label_value in value.items():
622+
if not isinstance(key, str):
623+
logger.warning(
624+
'Ignoring revision label with non-string key %r from %s.',
625+
key,
626+
source,
627+
)
628+
continue
629+
if not isinstance(label_value, str):
630+
logger.warning(
631+
'Ignoring revision label %s from %s because its value is not a '
632+
'string.',
633+
key,
634+
source,
635+
)
636+
continue
637+
revision_labels[key] = label_value
638+
639+
if not revision_labels:
640+
return None
641+
return revision_labels
520642

521643

522644
def _build_vertex_metadata(

tests/unittests/agents/test_callback_context.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
from google.adk.auth.auth_credential import AuthCredential
2323
from google.adk.auth.auth_credential import AuthCredentialTypes
2424
from google.adk.auth.auth_tool import AuthConfig
25+
from google.adk.memory.memory_entry import MemoryEntry
2526
from google.adk.tools.tool_context import ToolContext
27+
from google.genai import types
2628
from google.genai.types import Part
2729
import pytest
2830

@@ -417,7 +419,9 @@ async def test_add_memory_forwards_metadata(self, mock_invocation_context):
417419
"""Tests that add_memory forwards memories and metadata."""
418420
memory_service = AsyncMock()
419421
mock_invocation_context.memory_service = memory_service
420-
memories = ["fact one"]
422+
memories = [
423+
MemoryEntry(content=types.Content(parts=[types.Part(text="fact one")]))
424+
]
421425
metadata = {"ttl": "6000s"}
422426

423427
context = CallbackContext(mock_invocation_context)
@@ -430,6 +434,27 @@ async def test_add_memory_forwards_metadata(self, mock_invocation_context):
430434
custom_metadata=metadata,
431435
)
432436

437+
@pytest.mark.asyncio
438+
async def test_add_memory_accepts_memory_entries(
439+
self, mock_invocation_context
440+
):
441+
"""Tests that add_memory forwards MemoryEntry inputs unchanged."""
442+
memory_service = AsyncMock()
443+
mock_invocation_context.memory_service = memory_service
444+
memory_entry = MemoryEntry(
445+
content=types.Content(parts=[types.Part(text="fact one")])
446+
)
447+
448+
context = CallbackContext(mock_invocation_context)
449+
await context.add_memory(memories=[memory_entry])
450+
451+
memory_service.add_memory.assert_called_once_with(
452+
app_name=mock_invocation_context.session.app_name,
453+
user_id=mock_invocation_context.session.user_id,
454+
memories=[memory_entry],
455+
custom_metadata=None,
456+
)
457+
433458
@pytest.mark.asyncio
434459
async def test_add_memory_no_service_raises(self, mock_invocation_context):
435460
"""Tests that add_memory raises ValueError with no service."""
@@ -441,7 +466,13 @@ async def test_add_memory_no_service_raises(self, mock_invocation_context):
441466
ValueError,
442467
match=r"Cannot add memory: memory service is not available\.",
443468
):
444-
await context.add_memory(memories=["fact one"])
469+
await context.add_memory(
470+
memories=[
471+
MemoryEntry(
472+
content=types.Content(parts=[types.Part(text="fact one")])
473+
)
474+
]
475+
)
445476

446477

447478
class TestToolContextAddSessionToMemory:

0 commit comments

Comments
 (0)