Skip to content

Commit 40dcd2e

Browse files
committed
fix session not started issue
1 parent b186b5e commit 40dcd2e

File tree

4 files changed

+11
-4
lines changed

4 files changed

+11
-4
lines changed

taskweaver/ces/environment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,6 @@ def execute_code(
304304
) -> ExecutionResult:
305305
exec_id = get_id(prefix="exec") if exec_id is None else exec_id
306306
session = self._get_session(session_id)
307-
if session.kernel_status == "pending":
308-
self.start_session(session_id)
309307

310308
session.execution_count += 1
311309
execution_index = session.execution_count
@@ -438,6 +436,8 @@ def _get_session(
438436
)
439437
os.makedirs(new_session.session_dir, exist_ok=True)
440438
self.session_dict[session_id] = new_session
439+
elif session_id not in self.session_dict:
440+
raise ValueError(f"Session {session_id} not found.")
441441

442442
return self.session_dict.get(session_id, None)
443443

taskweaver/code_interpreter/code_executor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(
6363
self.plugin_loaded: bool = False
6464
self.config = config
6565
self.tracing = tracing
66+
self.session_variables = {}
6667

6768
@tracing_decorator
6869
def execute_code(self, exec_id: str, code: str) -> ExecutionResult:
@@ -75,6 +76,9 @@ def execute_code(self, exec_id: str, code: str) -> ExecutionResult:
7576
with get_tracer().start_as_current_span("load_plugin"):
7677
self.load_plugin()
7778
self.plugin_loaded = True
79+
80+
# update session variables
81+
self.exec_client.update_session_var(self.session_variables)
7882

7983
with get_tracer().start_as_current_span("run_code"):
8084
self.tracing.set_span_attribute("code", code)
@@ -104,6 +108,9 @@ def execute_code(self, exec_id: str, code: str) -> ExecutionResult:
104108
self.tracing.set_span_attribute("result", self.format_code_output(result, with_code=False))
105109

106110
return result
111+
112+
def update_session_var(self, session_var_dict: dict) -> None:
113+
self.session_variables.update(session_var_dict)
107114

108115
def _save_file(
109116
self,

taskweaver/code_interpreter/code_interpreter/code_interpreter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def get_intro(self) -> str:
122122

123123
def update_session_variables(self, session_variables: Dict[str, str]):
124124
self.logger.info(f"Updating session variables: {session_variables}")
125-
self.executor.exec_client.update_session_var(session_variables)
125+
self.executor.update_session_var(session_variables)
126126

127127
@tracing_decorator
128128
def reply(

taskweaver/code_interpreter/code_interpreter_plugin_only/code_interpreter_plugin_only.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def get_intro(self) -> str:
5757
return self.intro.format(plugin_description=self.plugin_description)
5858

5959
def update_session_variables(self, session_variables: dict) -> None:
60-
self.executor.exec_client.update_session_variables(session_variables)
60+
self.executor.update_session_var(session_variables)
6161

6262
@tracing_decorator
6363
def reply(

0 commit comments

Comments
 (0)