Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 156 additions & 42 deletions truss/cli/train/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
from rich.text import Text

from truss.cli.train import common, deploy_checkpoints
from truss.cli.train.file_visualizer import (
FileInfo,
FileTreeVisualizer,
VisualizationConfig,
VisualizationMetadata,
)
from truss.cli.train.metrics_watcher import MetricsWatcher
from truss.cli.train.types import (
DeployCheckpointArgs,
Expand Down Expand Up @@ -394,8 +400,18 @@ def download_training_job_data(


def download_checkpoint_artifacts(
remote_provider: BasetenRemote, job_id: Optional[str]
) -> Path:
remote_provider: BasetenRemote, job_id: Optional[str], view: Optional[str] = None
) -> Optional[Path]:
"""Download checkpoint artifacts or display them in a tree view.

Args:
remote_provider: The remote provider to use.
job_id: The job ID to get checkpoints for.
view: The view mode. If "tree", display as a tree. If None, save to JSON file.

Returns:
Path to the saved JSON file, or None if view mode is used.
"""
output_dir = Path.cwd()
job: dict

Expand All @@ -420,6 +436,15 @@ def download_checkpoint_artifacts(
if not checkpoint_artifacts:
raise click.ClickException("No checkpoints found for this training job.")

# If tree view is requested, display and return
if view == "tree":
timestamp = datetime.now(timezone.utc).isoformat()
_display_checkpoint_tree(
checkpoint_artifacts, job_id, project_id, project_name, timestamp
)
return None

# Otherwise, save to JSON file
output = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"job": job,
Expand All @@ -435,6 +460,44 @@ def download_checkpoint_artifacts(
return urls_file


def _display_checkpoint_tree(
checkpoint_artifacts: list[dict],
job_id: str,
project_id: str,
project_name: str,
timestamp: str,
) -> None:
# Convert checkpoint artifacts to FileInfo objects
file_infos = []
for artifact in checkpoint_artifacts:
# Extract the relative file name (path) and size from the artifact
path = artifact.get("relative_file_name", "unknown")
size_bytes = artifact.get("size_bytes", 0)
last_modified = artifact.get("last_modified")

file_infos.append(
FileInfo(path=path, size_bytes=size_bytes, modified=last_modified)
)

# Create visualization config
config = VisualizationConfig(
title=f"Checkpoints for Job: {job_id}",
metadata=VisualizationMetadata(
fields={
"Job ID": job_id,
"Project ID": project_id,
"Project Name": project_name,
"Retrieved At": timestamp,
}
),
files=file_infos,
)

# Display using the visualizer
visualizer = FileTreeVisualizer(config)
visualizer.display()


def status_page_url(remote_url: str, project_id: str, training_job_id: str) -> str:
return f"{remote_url}/training/{project_id}/logs/{training_job_id}"

Expand Down Expand Up @@ -673,6 +736,7 @@ def view_cache_summary(
project_id: str,
sort_by: str = SORT_BY_FILEPATH,
order: str = SORT_ORDER_ASC,
view: Optional[str] = None,
):
"""View cache summary for a training project."""
try:
Expand All @@ -684,64 +748,113 @@ def view_cache_summary(

cache_data = GetCacheSummaryResponseV1.model_validate(raw_cache_data)

table = rich.table.Table(title=f"Cache summary for project: {project_id}")
table.add_column("File Path", style="cyan")
table.add_column("Size", style="green")
table.add_column("Modified", style="yellow")
table.add_column("Type")
table.add_column("Permissions", style="magenta")

files = cache_data.file_summaries
if not files:
console.print("No files found in cache.", style="yellow")
return

files_with_total_sizes = create_file_summary_with_directory_sizes(files)
# Get project name
project = fetch_project_by_name_or_id(remote_provider, project_id)
project_name = project["name"]

# Apply sorting before converting to FileInfo
files_with_total_sizes = create_file_summary_with_directory_sizes(files)
reverse = order == SORT_ORDER_DESC
sort_key = _get_sort_key(sort_by)
files_with_total_sizes.sort(key=sort_key, reverse=reverse)

total_size = sum(
file_info.file_summary.size_bytes for file_info in files_with_total_sizes
)
total_size_str = common.format_bytes_to_human_readable(total_size)

console.print(
f"📅 Cache captured at: {cache_data.timestamp}", style="bold blue"
)
console.print(f"📁 Project ID: {cache_data.project_id}", style="bold blue")
console.print()
console.print(
f"📊 Total files: {len(files_with_total_sizes)}", style="bold green"
)
console.print(f"💾 Total size: {total_size_str}", style="bold green")
console.print()

for file_info in files_with_total_sizes:
total_size = file_info.total_size

size_str = cli_common.format_bytes_to_human_readable(int(total_size))

modified_str = cli_common.format_localized_time(
file_info.file_summary.modified
if view == "tree":
# Convert to FileInfo objects (using sorted list)
file_infos = [
FileInfo(
path=f.file_summary.path,
size_bytes=f.file_summary.size_bytes,
modified=f.file_summary.modified,
file_type=f.file_summary.file_type,
permissions=f.file_summary.permissions,
)
for f in files_with_total_sizes
]

# Create visualization config
config = VisualizationConfig(
title=f"Cache for Project: {project_name}",
metadata=VisualizationMetadata(
fields={
"Project ID": project_id,
"Captured At": cache_data.timestamp,
"Project Name": project_name,
}
),
files=file_infos,
)

table.add_row(
file_info.file_summary.path,
size_str,
modified_str,
file_info.file_summary.file_type or "Unknown",
file_info.file_summary.permissions or "Unknown",
# Display using the visualizer
visualizer = FileTreeVisualizer(config)
visualizer.display()
else:
# Default: Display as a simple table
_display_cache_summary_table(
files_with_total_sizes, project_name, project_id, cache_data.timestamp
)

console.print(table)

except Exception as e:
console.print(f"Error fetching cache summary: {str(e)}", style="red")
raise


def _display_cache_summary_table(
files_with_total_sizes: list[FileSummaryWithTotalSize],
project_name: str,
project_id: str,
timestamp: str,
) -> None:
"""Display cache summary in a simple table format.

Args:
files_with_total_sizes: List of files with their sizes.
project_name: The project name.
project_id: The project ID.
timestamp: The timestamp when the data was captured.
"""
# Display header
console.print(
f"\n📦 Cache Summary for Project: [bold cyan]{project_name}[/bold cyan]"
)
console.print(f"Project ID: {project_id}")
console.print(f"Captured At: {timestamp}\n")

# Create table
table = rich.table.Table(
show_header=True, header_style="bold magenta", box=rich.table.box.ROUNDED
)
table.add_column("Path", style="cyan")
table.add_column("Size", style="green", justify="right")
table.add_column("Type", style="yellow")
table.add_column("Modified", style="blue")
table.add_column("Permissions", style="white")

total_size = 0
for f in files_with_total_sizes:
size_str = common.format_bytes_to_human_readable(f.file_summary.size_bytes)
table.add_row(
f.file_summary.path,
size_str,
f.file_summary.file_type or "",
f.file_summary.modified or "",
f.file_summary.permissions or "",
)
total_size += f.file_summary.size_bytes

console.print(table)

# Display summary
console.print(f"\n📊 Total files: [bold]{len(files_with_total_sizes)}[/bold]")
console.print(
f"💾 Total size: [bold]{common.format_bytes_to_human_readable(total_size)}[/bold]\n"
)


def _get_sort_key(sort_by: str) -> Callable[[FileSummaryWithTotalSize], Any]:
if sort_by == SORT_BY_FILEPATH:
return lambda x: x.file_summary.path
Expand All @@ -762,7 +875,8 @@ def view_cache_summary_by_project(
project_identifier: str,
sort_by: str = SORT_BY_FILEPATH,
order: str = SORT_ORDER_ASC,
view: Optional[str] = None,
):
"""View cache summary for a training project by ID or name."""
project = fetch_project_by_name_or_id(remote_provider, project_identifier)
view_cache_summary(remote_provider, project["id"], sort_by, order)
view_cache_summary(remote_provider, project["id"], sort_by, order, view)
Loading
Loading