Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 119 additions & 47 deletions dataflow/operators/code/eval/python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
from multiprocessing import Queue, Process
from typing import Any, Dict, Optional, Tuple, List, Union
from tqdm import tqdm
from concurrent.futures import TimeoutError
from concurrent.futures import TimeoutError, ThreadPoolExecutor, Future
from contextlib import redirect_stdout
import base64
from io import BytesIO
from PIL import Image
import threading
try:
import matplotlib
matplotlib.use('Agg')
Expand Down Expand Up @@ -129,6 +130,65 @@ def start(self):
self.process.daemon = True
self.process.start()

def _execute_code_with_timeout(self, runtime, code, get_answer_from_stdout, answer_symbol, answer_expr, timeout):
"""Execute code with timeout protection using thread pool."""
def _run_code():
# Record the number of images before execution
pre_figures_count = len(runtime._global_vars.get("_captured_figures", []))

if get_answer_from_stdout:
program_io = io.StringIO()
with redirect_stdout(program_io):
runtime.exec_code("\n".join(code))
program_io.seek(0)
result = program_io.read()
elif answer_symbol:
runtime.exec_code("\n".join(code))
result = runtime._global_vars.get(answer_symbol, "")
elif answer_expr:
runtime.exec_code("\n".join(code))
result = runtime.eval_code(answer_expr)
else:
if len(code) > 1:
runtime.exec_code("\n".join(code[:-1]))
result = runtime.eval_code(code[-1])
else:
runtime.exec_code("\n".join(code))
result = ""

# Get newly generated images
all_figures = runtime._global_vars.get("_captured_figures", [])
new_figures = all_figures[pre_figures_count:]

# Prevent unbounded growth when the runtime is reused across multiple executions.
# We keep the pre-existing figures (e.g., injected/initial figures) and discard
# the newly captured ones after returning them in this response.
if new_figures and isinstance(all_figures, list):
runtime._global_vars["_captured_figures"] = all_figures[:pre_figures_count]

# Build result
if new_figures:
result = {
'text': result,
'images': new_figures
} if result else {'images': new_figures}
else:
result = {'text': result} if result else {}

return result

# Use thread pool executor with timeout
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(_run_code)
try:
result = future.result(timeout=timeout)
return result, None
except TimeoutError:
# Cancel the future (though it may still be running)
future.cancel()
# Note: The thread may still be running, but we've returned control
raise TimeoutError(f"Code execution exceeded timeout of {timeout} seconds")

def _worker_loop(self):
"""Main loop for the worker process."""
runtime = None
Expand Down Expand Up @@ -167,50 +227,28 @@ def _worker_loop(self):
get_answer_from_stdout = task.get('get_answer_from_stdout', True)
answer_symbol = task.get('answer_symbol')
answer_expr = task.get('answer_expr')
timeout = task.get('timeout', 30)

try:
# Record the number of images before execution
pre_figures_count = len(runtime._global_vars.get("_captured_figures", []))

if get_answer_from_stdout:
program_io = io.StringIO()
with redirect_stdout(program_io):
runtime.exec_code("\n".join(code))
program_io.seek(0)
result = program_io.read()
elif answer_symbol:
runtime.exec_code("\n".join(code))
result = runtime._global_vars.get(answer_symbol, "")
elif answer_expr:
runtime.exec_code("\n".join(code))
result = runtime.eval_code(answer_expr)
else:
if len(code) > 1:
runtime.exec_code("\n".join(code[:-1]))
result = runtime.eval_code(code[-1])
else:
runtime.exec_code("\n".join(code))
result = ""

# Get newly generated images
all_figures = runtime._global_vars.get("_captured_figures", [])
new_figures = all_figures[pre_figures_count:]

# Build result
if new_figures:
result = {
'text': result,
'images': new_figures
} if result else {'images': new_figures}
else:
result = {'text': result} if result else {}
# Execute with timeout protection
result, _ = self._execute_code_with_timeout(
runtime, code, get_answer_from_stdout, answer_symbol, answer_expr, timeout
)

self.output_queue.put({
'status': 'success',
'result': result,
'report': 'Done'
})

except TimeoutError as e:
# Timeout occurred - return error but continue processing
self.output_queue.put({
'status': 'error',
'error': str(e),
'traceback': traceback.format_exc(),
'report': 'Timeout Error'
})
except Exception as e:
self.output_queue.put({
'status': 'error',
Expand Down Expand Up @@ -249,16 +287,18 @@ def execute(self, code: List[str], messages: list = None, runtime_class=None,
'runtime_class': runtime_identifier,
'get_answer_from_stdout': get_answer_from_stdout,
'answer_symbol': answer_symbol,
'answer_expr': answer_expr
'answer_expr': answer_expr,
'timeout': timeout
})

try:
result = self.output_queue.get(timeout=timeout)
# Add extra buffer time for queue communication
result = self.output_queue.get(timeout=timeout + 5)
return result
except queue.Empty:
return {
'status': 'error',
'error': 'Execution timeout',
'error': 'Execution timeout (queue timeout)',
'report': 'Timeout Error'
}

Expand Down Expand Up @@ -307,10 +347,6 @@ def __init__(self):
self.exec_code(c)

def exec_code(self, code_piece: str) -> None:
# Security check
if regex.search(r"(\s|^)?(input|os\.system|subprocess)\(", code_piece):
raise RuntimeError("Forbidden function calls detected")

# Detect and modify plt.show() calls
if "plt.show()" in code_piece and MATPLOTLIB_AVAILABLE:
modified_code = code_piece.replace("plt.show()", """
Expand Down Expand Up @@ -426,9 +462,25 @@ def __init__(
self.persistent_worker = None

def _ensure_worker(self):
"""Ensure the worker process exists."""
"""Ensure the worker process exists and is healthy."""
if self.persistent_worker is None:
self.persistent_worker = PersistentWorker()
elif self.persistent_worker.process is not None and not self.persistent_worker.process.is_alive():
# Worker process died, restart it
try:
self.persistent_worker.terminate()
except:
pass
self.persistent_worker = PersistentWorker()

def _restart_worker(self):
"""Force restart the worker process."""
if self.persistent_worker is not None:
try:
self.persistent_worker.terminate()
except:
pass
self.persistent_worker = PersistentWorker()

def process_generation_to_code(self, gens: str):
    """Split each generation into its individual source lines.

    Iterates over ``gens`` and returns a list with one entry per item, each
    entry being that item split on ``"\\n"``.
    NOTE(review): despite the ``str`` annotation, ``batch_apply`` passes its
    batch of code strings here -- confirm the intended type.
    """
    snippets = []
    for generation in gens:
        snippets.append(generation.split("\n"))
    return snippets
Expand Down Expand Up @@ -531,17 +583,23 @@ def truncate(s, max_length=400):
return s

def batch_apply(self, batch_code, messages):
"""
Execute a batch of code snippets.

Args:
batch_code: List of code strings to execute
messages: Context messages for execution
"""
all_code_snippets = self.process_generation_to_code(batch_code)

timeout_cnt = 0
all_exec_results = []

if len(all_code_snippets) > 100:
progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
else:
progress_bar = None

for code in all_code_snippets:
for idx, code in enumerate(all_code_snippets):
try:
result = self.execute(
code,
Expand All @@ -551,14 +609,28 @@ def batch_apply(self, batch_code, messages):
answer_symbol=self.answer_symbol,
answer_expr=self.answer_expr,
)
self._restart_worker()

all_exec_results.append(result)
except TimeoutError as error:
print(error)
all_exec_results.append(("", "Timeout Error"))
timeout_cnt += 1
self._restart_worker()

# If timeout occurs, try resetting runtime to recover
if self.use_process_isolation and self.persistent_worker:
try:
self.reset(messages)
except Exception:
# If reset fails, try restarting worker
try:
self._restart_worker()
except:
pass
except Exception as error:
print(f"Error in batch_apply: {error}")
all_exec_results.append(("", f"Error: {str(error)}"))
self._restart_worker()

if progress_bar is not None:
progress_bar.update(1)
Expand Down