-
Notifications
You must be signed in to change notification settings - Fork 11.9k
/
Copy pathvariable_pool.py
162 lines (125 loc) · 5.01 KB
/
variable_pool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union
from pydantic import BaseModel, Field, model_validator
from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.enums import SystemVariableKey
# Type alias for the plain (non-Segment) values a variable may hold.
# In this file it is only mentioned in the `VariablePool.add` docstring.
VariableValue = Union[str, int, float, dict, list, FileVar]
# Node ids used as the first selector element for the variable groups that
# `VariablePool` registers itself during validation: system, environment and
# conversation variables respectively.
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
class VariablePool(BaseModel):
    """Store workflow variables and look them up by selector.

    A selector is a sequence of strings. Its first element is a node id and
    the remaining elements identify a variable inside that node. System,
    environment and conversation variables passed to the constructor are
    registered in the pool automatically after validation.
    """

    # Two-level lookup table for variables.
    # The first-level key is the node id (the first element of the selector).
    # The second-level key is ``hash(tuple(selector[1:]))``, i.e. the hash of
    # the remaining selector elements.
    # NOTE: Python salts ``hash()`` of strings per process, so these integer
    # keys are only meaningful inside the process that created them.
    variable_dictionary: dict[str, dict[int, Segment]] = Field(
        description='Variables mapping',
        # A factory builds a fresh mapping for every instance, consistent with
        # the other mutable default on this model.
        default_factory=lambda: defaultdict(dict),
    )
    # TODO: These user inputs are not used by the pool.
    user_inputs: Mapping[str, Any] = Field(
        description='User inputs',
    )
    system_variables: Mapping[SystemVariableKey, Any] = Field(
        description='System variables',
    )
    environment_variables: Sequence[Variable] = Field(
        description="Environment variables.",
        default_factory=list,
    )
    conversation_variables: Sequence[Variable] | None = None

    @model_validator(mode="after")
    def val_model_after(self):
        """
        Register system, environment and conversation variables in the pool.

        Returns:
            VariablePool: The validated instance itself.
        """
        # System variables are keyed by the enum member's value.
        for key, value in self.system_variables.items():
            self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
        # Environment variables are keyed by their name.
        for var in self.environment_variables or []:
            self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
        # Conversation variables are optional (the field may be None).
        for var in self.conversation_variables or []:
            self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
        return self

    @staticmethod
    def _split_selector(selector: Sequence[str]) -> tuple[str, int]:
        """
        Validate a selector and split it into its two lookup keys.

        Args:
            selector (Sequence[str]): The selector to split.

        Returns:
            tuple[str, int]: The node id and the hash of the remaining elements.

        Raises:
            ValueError: If the selector has fewer than two elements.
        """
        if len(selector) < 2:
            raise ValueError("Invalid selector")
        return selector[0], hash(tuple(selector[1:]))

    def add(self, selector: Sequence[str], value: Any, /) -> None:
        """
        Adds a variable to the variable pool.

        Args:
            selector (Sequence[str]): The selector for the variable.
            value (Any): The value of the variable. A ``Segment`` is stored
                as is; any other value is converted with
                ``factory.build_segment``. ``None`` is ignored.

        Raises:
            ValueError: If the selector has fewer than two elements.

        Returns:
            None
        """
        # Validate first so an invalid selector is reported even for None values.
        node_id, hash_key = self._split_selector(selector)
        if value is None:
            return
        segment = value if isinstance(value, Segment) else factory.build_segment(value)
        # setdefault works for both a defaultdict and a plain dict, which the
        # mapping may be when it was supplied explicitly to the constructor.
        self.variable_dictionary.setdefault(node_id, {})[hash_key] = segment

    def get(self, selector: Sequence[str], /) -> Segment | None:
        """
        Retrieves the segment from the variable pool based on the given selector.

        Args:
            selector (Sequence[str]): The selector used to identify the variable.

        Returns:
            Segment | None: The stored segment, or None if no variable matches.

        Raises:
            ValueError: If the selector has fewer than two elements.
        """
        node_id, hash_key = self._split_selector(selector)
        # Look the node up with .get() so a miss neither inserts an empty
        # entry (defaultdict) nor raises KeyError (plain dict).
        node_variables = self.variable_dictionary.get(node_id)
        if node_variables is None:
            return None
        return node_variables.get(hash_key)

    @deprecated("This method is deprecated, use `get` instead.")
    def get_any(self, selector: Sequence[str], /) -> Any | None:
        """
        Retrieves the plain value from the variable pool based on the given selector.

        Args:
            selector (Sequence[str]): The selector used to identify the variable.

        Returns:
            Any | None: The result of ``to_object()`` on the stored segment,
                or None if no variable matches.

        Raises:
            ValueError: If the selector has fewer than two elements.
        """
        segment = self.get(selector)
        # NOTE(review): the truthiness check is kept from the original; if
        # Segment can ever be falsy this returns None for it -- confirm.
        return segment.to_object() if segment else None

    def remove(self, selector: Sequence[str], /):
        """
        Remove variables from the variable pool based on the given selector.

        An empty selector is a no-op. A single-element selector resets the
        node's variables to an empty mapping. A longer selector removes only
        the matching variable, if present.

        Args:
            selector (Sequence[str]): A sequence of strings representing the selector.

        Returns:
            None
        """
        if not selector:
            return
        if len(selector) == 1:
            self.variable_dictionary[selector[0]] = {}
            return
        node_id, hash_key = self._split_selector(selector)
        node_variables = self.variable_dictionary.get(node_id)
        # Nothing to remove when the node is unknown; do not create an entry.
        if node_variables is not None:
            node_variables.pop(hash_key, None)

    def remove_node(self, node_id: str, /):
        """
        Remove all variables associated with a given node id.

        Unlike ``remove`` with a single-element selector, this drops the
        node's entry from the mapping entirely.

        Args:
            node_id (str): The node id to remove.

        Returns:
            None
        """
        self.variable_dictionary.pop(node_id, None)