Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 208 additions & 0 deletions src/strands_tools/code_interpreter/agent_core_code_interpreter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import base64
import json
import logging
import os
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

from bedrock_agentcore.tools.code_interpreter_client import CodeInterpreter as BedrockAgentCoreCodeInterpreterClient

from ..utils.aws_util import resolve_region
from .code_interpreter import CodeInterpreter
from .models import (
DownloadFilesAction,
ExecuteCodeAction,
ExecuteCommandAction,
InitSessionAction,
Expand Down Expand Up @@ -476,6 +481,209 @@ def write_files(self, action: WriteFilesAction) -> Dict[str, Any]:

return self._create_tool_result(response)

def download_files(self, action: DownloadFilesAction) -> Dict[str, Any]:
"""
Download files from the Code Interpreter sandbox to the local file system.

This method downloads files by:
1. Executing Python code in the sandbox to read and base64-encode the files
2. Retrieving the base64-encoded content from the execution result
3. Decoding and saving the files to the specified local destination directory

Args:
action (DownloadFilesAction): Action containing source paths and destination directory

Returns:
Dict[str, Any]: Response dictionary containing download results or error details.
Success response includes downloadedFiles list with local file paths.

Raises:
Exception: If session management fails, file encoding/decoding fails, or
file system operations encounter errors.
"""
session_name, error = self._ensure_session(action.session_name)
if error:
return error

logger.debug(f"Downloading {len(action.source_paths)} files from session '{session_name}'")

try:
# Validate destination directory and create if it doesn't exist
dest_path = Path(action.destination_dir)
if not dest_path.is_absolute():
return {
"status": "error",
"content": [{"text": f"Destination directory must be an absolute path: {action.destination_dir}"}],
}

# Create destination directory if it doesn't exist
dest_path.mkdir(parents=True, exist_ok=True)

# Generate Python code to read and base64-encode the files in the sandbox
source_paths_json = json.dumps(action.source_paths)
encode_code = f"""
import base64
import json
import os

results = {{}}
source_paths = {source_paths_json}

for path in source_paths:
try:
if not os.path.exists(path):
results[path] = {{"error": f"File not found: {{path}}"}}
continue

with open(path, 'rb') as f:
file_data = f.read()
results[path] = {{
"data": base64.b64encode(file_data).decode('utf-8'),
"size": len(file_data)
}}
except Exception as e:
results[path] = {{"error": f"Failed to read file {{path}}: {{str(e)}}"}}

print("__DOWNLOAD_RESULTS__")
print(json.dumps(results))
print("__DOWNLOAD_RESULTS_END__")
"""

# Execute the encoding code in the sandbox
params = {"code": encode_code, "language": "python", "clearContext": False}
response = self._sessions[session_name].client.invoke("executeCode", params)

# Extract the execution result
execution_result = self._create_tool_result(response)
if execution_result.get("status") != "success":
return {
"status": "error",
"content": [{"text": f"Failed to execute file encoding in sandbox: {execution_result}"}],
}

# Parse the base64-encoded results from the output
content = execution_result["content"][0]
if isinstance(content, dict) and "text" in content:
output_text = content["text"]
else:
output_text = str(content)

# Handle case where output_text might be a list representation
if output_text.startswith("[{") and "text" in output_text:
import re

# Extract text from list format: [{'type': 'text', 'text': '...'}]
match = re.search(r"'text':\s*'([^']*(?:\\'[^']*)*)'", output_text)
if match:
# Unescape the captured text
output_text = match.group(1).replace("\\'", "'").replace("\\n", "\n").replace("\\\\", "\\")
else:
logger.warning(f"Could not extract text from list format: {output_text[:200]}...")

logger.debug(f"Extracted text: {output_text[:200]}...")

# Extract JSON results between markers
start_marker = "__DOWNLOAD_RESULTS__"
end_marker = "__DOWNLOAD_RESULTS_END__"

start_idx = output_text.find(start_marker)
end_idx = output_text.find(end_marker)

if start_idx == -1 or end_idx == -1:
return {
"status": "error",
"content": [
{
"text": f"Could not find download results in output. "
f"Start marker found: {start_idx >= 0}, End marker found: {end_idx >= 0}. "
f"Output: {output_text[:1000]}..."
}
],
}

json_start = start_idx + len(start_marker)
results_json = output_text[json_start:end_idx].strip()
logger.debug(f"Extracted JSON: '{results_json}'")

if not results_json:
return {
"status": "error",
"content": [{"text": f"Empty JSON results between markers. Full output: {output_text}"}],
}

try:
file_results = json.loads(results_json)
except json.JSONDecodeError as e:
return {
"status": "error",
"content": [
{
"text": f"Failed to parse download results JSON: {e}. "
f"JSON string: '{results_json}'. Full output: {output_text}"
}
],
}

# Process each file result
downloaded_files = []
errors = []

for source_path, result in file_results.items():
if "error" in result:
errors.append(f"{source_path}: {result['error']}")
continue

try:
# Decode base64 data
file_data = base64.b64decode(result["data"])

# Determine local file path
source_filename = os.path.basename(source_path)
local_path = dest_path / source_filename

# Handle filename conflicts by adding a counter
counter = 1
base_name = source_filename
while local_path.exists():
if "." in base_name:
name, ext = base_name.rsplit(".", 1)
local_path = dest_path / f"{name}_{counter}.{ext}"
else:
local_path = dest_path / f"{base_name}_{counter}"
counter += 1

# Write file to local filesystem
with open(local_path, "wb") as f:
f.write(file_data)

downloaded_files.append(
{"sourcePath": source_path, "localPath": str(local_path), "size": result["size"]}
)

logger.info(f"Downloaded file: {source_path} -> {local_path} ({result['size']} bytes)")

except Exception as e:
errors.append(f"{source_path}: Failed to decode/save file: {str(e)}")

# Prepare response
if errors and not downloaded_files:
return {"status": "error", "content": [{"text": f"All downloads failed: {'; '.join(errors)}"}]}

response_data = {
"downloadedFiles": downloaded_files,
"totalFiles": len(downloaded_files),
"destinationDir": str(dest_path),
}

if errors:
response_data["errors"] = errors

return {"status": "success", "content": [{"json": response_data}]}

except Exception as e:
logger.error(f"Failed to download files from session '{session_name}': {str(e)}")
return {"status": "error", "content": [{"text": f"Failed to download files: {str(e)}"}]}

def _create_tool_result(self, response) -> Dict[str, Any]:
"""Create tool result from response."""
if "stream" in response:
Expand Down
11 changes: 11 additions & 0 deletions src/strands_tools/code_interpreter/code_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from .models import (
CodeInterpreterInput,
DownloadFilesAction,
ExecuteCodeAction,
ExecuteCommandAction,
InitSessionAction,
Expand Down Expand Up @@ -84,6 +85,7 @@ def __init__(self):
- writeFiles: Create or update files in the sandbox
- listFiles: Browse directory contents and file structures
- removeFiles: Delete files from the sandbox environment
- downloadFiles: Download files from sandbox to local filesystem

Common Usage Scenarios:
---------------------
Expand Down Expand Up @@ -167,6 +169,8 @@ def __init__(self):
- WriteFilesAction: type="writeFiles", session_name, content (list of FileContent objects)
- ListFilesAction: type="listFiles", session_name, path
- RemoveFilesAction: type="removeFiles", session_name, paths (list)
- DownloadFilesAction: type="downloadFiles", session_name, source_paths (list),
destination_dir (optional, defaults to /tmp)
- ListLocalSessionsAction: type="listLocalSessions"

Returns:
Expand Down Expand Up @@ -250,6 +254,8 @@ def code_interpreter(self, code_interpreter_input: CodeInterpreterInput) -> Dict
return self.remove_files(action)
elif isinstance(action, WriteFilesAction):
return self.write_files(action)
elif isinstance(action, DownloadFilesAction):
return self.download_files(action)
else:
return {"status": "error", "content": [{"text": f"Unknown action: {type(action)}"}]}

Expand Down Expand Up @@ -322,6 +328,11 @@ def write_files(self, action: WriteFilesAction) -> Dict[str, Any]:
"""Write files to a sandbox session."""
...

@abstractmethod
def download_files(self, action: DownloadFilesAction) -> Dict[str, Any]:
"""Download files from a sandbox session to the local filesystem."""
...

@abstractmethod
def list_local_sessions(self) -> Dict[str, Any]:
"""List all sessions created by this platform instance."""
Expand Down
18 changes: 18 additions & 0 deletions src/strands_tools/code_interpreter/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,23 @@ class WriteFilesAction(BaseModel):
content: List[FileContent] = Field(description="Required list of file content to write")


class DownloadFilesAction(BaseModel):
"""Download files from the Code Interpreter sandbox to the local file system. Use this to retrieve generated
files (CSV, Excel, images, etc.) from the session after data analysis or file processing. The files are
downloaded as binary data and saved to the specified local directory."""

type: Literal["downloadFiles"] = Field(description="Download files from the code interpreter to local filesystem")

session_name: Optional[str] = Field(
default=None, description="Session name. If not provided, uses the default session."
)

source_paths: List[str] = Field(description="Required list of file paths in the sandbox to download")
destination_dir: str = Field(
default="/tmp", description="Local directory to save downloaded files (defaults to /tmp)"
)


class CodeInterpreterInput(BaseModel):
action: Union[
InitSessionAction,
Expand All @@ -139,4 +156,5 @@ class CodeInterpreterInput(BaseModel):
ListFilesAction,
RemoveFilesAction,
WriteFilesAction,
DownloadFilesAction,
] = Field(discriminator="type")
4 changes: 2 additions & 2 deletions src/strands_tools/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@
"lessThanOrEquals, in (value in list), notIn, listContains (list contains value), "
"stringContains (substring match), startsWith (OpenSearch Serverless only), "
"andAll (all conditions must match, min 2 items), orAll (at least one condition must match, "
"min 2 items). Example: {\"andAll\": [{\"equals\": {\"key\": \"category\", "
"\"value\": \"security\"}}, {\"greaterThan\": {\"key\": \"year\", \"value\": \"2022\"}}]}"
'min 2 items). Example: {"andAll": [{"equals": {"key": "category", '
'"value": "security"}}, {"greaterThan": {"key": "year", "value": "2022"}}]}'
),
},
},
Expand Down
Loading
Loading