|  | 
|  | 1 | +# Copyright 2024 Google LLC | 
|  | 2 | +# | 
|  | 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +# you may not use this file except in compliance with the License. | 
|  | 5 | +# You may obtain a copy of the License at | 
|  | 6 | +# | 
|  | 7 | +#      http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | +# | 
|  | 9 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +# See the License for the specific language governing permissions and | 
|  | 13 | +# limitations under the License. | 
|  | 14 | + | 
|  | 15 | +import json | 
|  | 16 | +import redis | 
|  | 17 | +from typing import List, Optional | 
|  | 18 | + | 
|  | 19 | +from langchain_core.chat_history import BaseChatMessageHistory | 
|  | 20 | +from langchain_core.messages import ( | 
|  | 21 | +    BaseMessage, | 
|  | 22 | +    message_to_dict, | 
|  | 23 | +    messages_from_dict, | 
|  | 24 | +) | 
|  | 25 | + | 
|  | 26 | + | 
|  | 27 | +class MemorystoreChatMessageHistory(BaseChatMessageHistory): | 
|  | 28 | +    """Chat message history stored in a Cloud Memorystore for Redis database.""" | 
|  | 29 | + | 
|  | 30 | +    def __init__( | 
|  | 31 | +        self, | 
|  | 32 | +        client: redis.Redis, | 
|  | 33 | +        session_id: str, | 
|  | 34 | +        ttl: Optional[int] = None, | 
|  | 35 | +    ): | 
|  | 36 | +        """Initializes the chat message history for Memorystore for Redis. | 
|  | 37 | +
 | 
|  | 38 | +        Args: | 
|  | 39 | +            client: A redis.Redis client object. | 
|  | 40 | +            session_id: A string that uniquely identifies the chat history. | 
|  | 41 | +            ttl: Specifies the time in seconds after which the session will | 
|  | 42 | +                expire and be eliminated from the Redis instance since the most | 
|  | 43 | +                recent message is added. | 
|  | 44 | +        """ | 
|  | 45 | + | 
|  | 46 | +        self._redis = client | 
|  | 47 | +        self._key = session_id | 
|  | 48 | +        self._ttl = ttl | 
|  | 49 | + | 
|  | 50 | +    @property | 
|  | 51 | +    def messages(self) -> List[BaseMessage]: | 
|  | 52 | +        """Retrieve all messages chronologically stored in this session.""" | 
|  | 53 | +        all_elements = self._redis.lrange(self._key, 0, -1) | 
|  | 54 | +        messages = messages_from_dict( | 
|  | 55 | +            [json.loads(e.decode("utf-8")) for e in all_elements] | 
|  | 56 | +        ) | 
|  | 57 | +        return messages | 
|  | 58 | + | 
|  | 59 | +    def add_message(self, message: BaseMessage) -> None: | 
|  | 60 | +        """Append one message to this session.""" | 
|  | 61 | +        self._redis.rpush(self._key, json.dumps(message_to_dict(message))) | 
|  | 62 | +        if self._ttl: | 
|  | 63 | +            self._redis.expire(self._key, self._ttl) | 
|  | 64 | + | 
|  | 65 | +    def clear(self) -> None: | 
|  | 66 | +        """Clear all messages in this session.""" | 
|  | 67 | +        self._redis.delete(self._key) | 
0 commit comments