191 changes: 149 additions & 42 deletions examples/tool_use/generate_with_tool.py
@@ -1,5 +1,7 @@
import numpy as np
import os
import re
import time
import uuid
from typing import Any, Dict, List

@@ -9,6 +11,9 @@

# Import jupyter tool functionality
from .jupyter_tool import SEMAPHORE, tool_registry
from .statistic_metrics import get_init_metrics

debug = os.getenv("SLIME_DEBUG", "False").lower() == "true"

def format_prompt(state: GenerateState, prompt: str, tool_specs: List[Dict[str, Any]]) -> str:
# if prompt has already applied chat template
@@ -73,20 +78,31 @@ def postprocess_predictions(prediction: str):

return None, ""

async def execute_predictions(session_id: str, state: GenerateState, prediction: str) -> str:
async def execute_predictions(session_id: str, state: GenerateState, prediction: str, log_dict: dict) -> str:
"""Execute predictions and return results"""
tool_wait_lock_time, tool_execution_time, tool_wait_lock_and_execution_times = 0.0, 0.0, 0.0

action, content = postprocess_predictions(prediction)

if action == "tool_call_error":
next_obs = state.tokenizer.apply_chat_template([{"role": "tool", "content": content}],
add_generation_prompt=True, tokenize=False)
done = False
if action == "code":
elif action == "code":
# Content is already the Python code (extracted by postprocess_predictions)
code = content.strip()
if code:
tool_wait_lock_start_time = time.time()
async with SEMAPHORE:
tool_wait_lock_end_time = time.time()
tool_wait_lock_time = tool_wait_lock_end_time - tool_wait_lock_start_time

tool_execution_start_time = time.time()
result = await tool_registry.execute_tool("python", {"code": code}, session_id)
tool_execution_end_time = time.time()

tool_execution_time = tool_execution_end_time - tool_execution_start_time
tool_wait_lock_and_execution_times = tool_execution_end_time - tool_wait_lock_start_time
else:
result = "Error: No Python code found"

@@ -100,10 +116,14 @@ async def execute_predictions(session_id: str, state: GenerateState, prediction:
next_obs = ""
done = True

log_dict["tool_wait_lock_times"].append(tool_wait_lock_time)
log_dict["tool_execution_times"].append(tool_execution_time)
log_dict["tool_wait_lock_and_execution_times"].append(tool_wait_lock_and_execution_times)

return next_obs, done
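
The semaphore block added above separates queueing time from execution time by timestamping before and after lock acquisition. A minimal standalone sketch of the same pattern (the semaphore limit and the sleep stand-in are hypothetical, not from this PR):

```python
import asyncio
import time

SEMAPHORE = asyncio.Semaphore(4)  # hypothetical concurrency cap

async def timed_tool_call() -> tuple[float, float]:
    t0 = time.time()
    async with SEMAPHORE:
        wait_time = time.time() - t0   # time spent queued for the semaphore
        t1 = time.time()
        await asyncio.sleep(0.1)       # stand-in for tool_registry.execute_tool
        exec_time = time.time() - t1
    return wait_time, exec_time

print(asyncio.run(timed_tool_call()))
```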

def postprocess_sample(sample: Sample, prompt_token_ids: List[int], response_token_ids: List[int],
loss_masks: List[int], response: str, max_new_tokens: int, tokenizer) -> Sample:
loss_masks: List[int], response: str, max_new_tokens: int, tokenizer, log_dict: dict) -> Sample:
if len(response_token_ids) > max_new_tokens:
response_token_ids = response_token_ids[:max_new_tokens]
response = tokenizer.decode(response_token_ids, skip_special_tokens=False)
@@ -116,15 +136,96 @@ def postprocess_sample(sample: Sample, prompt_token_ids: List[int], response_tok
sample.response = response
sample.loss_mask = loss_masks

# # Store payload information for wandb logging
# sample.payload_text = prompt + response
# sample.payload_has_system = "<|im_start|>system" in prompt + response
# sample.payload_has_tools = "# Tools" in prompt + response

# # Store tool call count for reward calculation
# sample.tool_call_count = tool_call_count
sample.train_metadata = postprocess_log_dict(sample, log_dict)
print(f"[session_id: {log_dict.get('session_id')}, group idx: {sample.group_index}, sample idx: {sample.index}] "\
f"debug log dict: {sample.train_metadata}")
return sample

def postprocess_log_dict(sample: Sample, log_dict: dict) -> dict:
print_prefix = f"[session_id: {log_dict.get('session_id')}, group idx: {sample.group_index}, sample idx: {sample.index}]"

log_dict["total_length"] = len(sample.prompt) + len(sample.response)
log_dict["total_token_length"] = len(sample.tokens)
log_dict["response_length"] = len(sample.response)
log_dict["response_token_length"] = sample.response_length
log_dict["tool_call_count"] = sample.response.count("<tool_response>")

eps = 1e-8 # to avoid division by zero
# round statistics
if len(log_dict["round_total_times"]) > len(log_dict["sgl_generation_times"]):
log_dict["sgl_generation_times"].extend(
[0.0] * (len(log_dict["round_total_times"]) - len(log_dict["sgl_generation_times"]))
)
assert len(log_dict["round_total_times"]) == len(log_dict["sgl_generation_times"]), \
f"{print_prefix} len(round_total_times): {len(log_dict['round_total_times'])}, len(sgl_generation_times): "\
f"{len(log_dict['sgl_generation_times'])} are not equal."
# Per-round ratio of SGLang generation time to the round's total generation time
log_dict["sgl_generation_time_ratios"] = (
np.array(log_dict["sgl_generation_times"]) / (np.array(log_dict["round_total_times"]) + eps)
).tolist()

assert len(log_dict["tool_execution_times"]) == len(log_dict["tool_wait_lock_times"]) == len(log_dict["tool_wait_lock_and_execution_times"]), \
f"{print_prefix} len(tool_execution_times): {len(log_dict['tool_execution_times'])}, len(tool_wait_lock_times): {len(log_dict['tool_wait_lock_times'])}, "\
f"len(tool_wait_lock_and_execution_times): {len(log_dict['tool_wait_lock_and_execution_times'])} are not equal."
if len(log_dict["round_total_times"]) > len(log_dict["tool_execution_times"]):
extend_len = len(log_dict["round_total_times"]) - len(log_dict["tool_execution_times"])
log_dict["tool_execution_times"].extend([0.0] * extend_len)
log_dict["tool_wait_lock_times"].extend([0.0] * extend_len)
log_dict["tool_wait_lock_and_execution_times"].extend([0.0] * extend_len)
assert len(log_dict["round_total_times"]) == len(log_dict["tool_wait_lock_and_execution_times"]), \
f"{print_prefix} len(round_total_times): {len(log_dict['round_total_times'])}, len(tool_wait_lock_and_execution_times): "\
f"{len(log_dict['tool_wait_lock_and_execution_times'])} are not equal."

# Per-round ratio of tool execution time to total tool-call time (semaphore wait + execution)
log_dict["tool_execution_time_ratios_for_tool_time"] = (
np.array(log_dict["tool_execution_times"]) / (np.array(log_dict["tool_wait_lock_and_execution_times"]) + eps)
).tolist()
# Per-round ratio of semaphore-wait time to total tool-call time (semaphore wait + execution)
log_dict["tool_wait_lock_time_ratios_for_tool_time"] = (
np.array(log_dict["tool_wait_lock_times"]) / (np.array(log_dict["tool_wait_lock_and_execution_times"]) + eps)
).tolist()
# Per-round ratio of total tool-call time (semaphore wait + execution) to the round's total generation time
log_dict["tool_wait_lock_and_execution_time_ratios"] = (
np.array(log_dict["tool_wait_lock_and_execution_times"]) / (np.array(log_dict["round_total_times"]) + eps)
).tolist()

# sample statistics
log_dict["total_tool_execution_time"] = sum(log_dict["tool_execution_times"])
log_dict["total_tool_wait_lock_time"] = sum(log_dict["tool_wait_lock_times"])
log_dict["total_tool_wait_lock_and_execution_time"] = sum(log_dict["tool_wait_lock_and_execution_times"])
log_dict["total_sgl_generation_time"] = sum(log_dict["sgl_generation_times"])
log_dict["total_time"] = sum(log_dict["round_total_times"])
# Ratio of total tool execution time to total tool-call time across the sample
log_dict["total_tool_execution_time_ratio_for_total_tool_time"] = (
log_dict["total_tool_execution_time"] / (log_dict["total_tool_wait_lock_and_execution_time"] + eps)
)
# Ratio of total semaphore-wait time to total tool-call time across the sample
log_dict["total_tool_wait_lock_time_ratio_for_total_tool_time"] = (
log_dict["total_tool_wait_lock_time"] / (log_dict["total_tool_wait_lock_and_execution_time"] + eps)
)
# Ratio of total tool-call time (semaphore wait + execution) to total generation time
log_dict["total_tool_wait_lock_and_execution_time_ratio"] = (
log_dict["total_tool_wait_lock_and_execution_time"] / (log_dict["total_time"] + eps)
)
# Ratio of total SGLang generation time to total generation time
log_dict["total_sgl_generation_time_ratio"] = (
log_dict["total_sgl_generation_time"] / (log_dict["total_time"] + eps)
)

return log_dict
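
To make the padding and ratio arithmetic above concrete, here is a small self-contained example with hypothetical timings. The last round of a rollout can end (e.g. on the length limit) before any tool timing is appended, so the tool lists are zero-padded to the length of `round_total_times` before the element-wise division:

```python
import numpy as np

eps = 1e-8  # guards against division by zero, as in postprocess_log_dict

# Hypothetical per-round timings (seconds) for a 3-round sample whose last
# round hit the length limit before any tool call was logged.
round_total_times = [4.0, 6.0, 2.0]
tool_wait_lock_and_execution_times = [1.0, 3.0]  # one entry short

# Zero-pad the shorter list so the element-wise division lines up.
pad = len(round_total_times) - len(tool_wait_lock_and_execution_times)
tool_wait_lock_and_execution_times.extend([0.0] * pad)

ratios = (
    np.array(tool_wait_lock_and_execution_times) / (np.array(round_total_times) + eps)
).tolist()
print(ratios)  # ~[0.25, 0.5, 0.0]
```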

def report_wandb(log_dict: dict):
report_dict = {
f"debug/{k}": v for k, v in log_dict.items() if not isinstance(v, list)
}
try:
import wandb

if wandb.run is not None:
wandb.log(report_dict)
except ImportError:
pass # wandb not available
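
For illustration (values hypothetical), the dict comprehension above forwards only scalar entries to wandb under the `debug/` prefix, keeping the per-round lists out of the report:

```python
log_dict = {"total_time": 12.3, "turn": 4, "round_total_times": [4.0, 6.0, 2.3]}
report_dict = {f"debug/{k}": v for k, v in log_dict.items() if not isinstance(v, list)}
print(report_dict)  # {'debug/total_time': 12.3, 'debug/turn': 4}
```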

async def generate(args, sample: Sample, sampling_params) -> Sample:
"""Custom generation function supporting tool calls"""
assert not args.partial_rollout, "Partial rollout is not supported for this function at the moment."
@@ -137,14 +238,17 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
), f"Sample status is {sample.status}"

session_id = "generate_" + uuid.uuid4().hex
debug = os.getenv("SLIME_DEBUG", "False").lower() == "true"

# Set up the initial prompt with system prompt and tools (outside the loop)
tool_specs = tool_registry.get_tool_specs()

# Count available tools (from tool_specs)
available_tools = len(tool_specs)

prompt = format_prompt(state, sample.prompt, tool_specs)
if debug:
print(f"sample.prompt:\n {sample.prompt}\nFormatted prompt:\n {prompt}")
print(f"[session_id: {session_id}] sample.prompt:\n {sample.prompt}\nFormatted prompt:\n {prompt}")
# convert sample.prompt to formatted prompt
sample.prompt = prompt

prompt_token_ids = state.tokenizer(prompt, add_special_tokens=False)["input_ids"]
response = ""
@@ -154,8 +258,13 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
max_new_tokens = sampling_params["max_new_tokens"]
turn = 0

debug_log_dict = get_init_metrics()
debug_log_dict["available_tools"] = available_tools
debug_log_dict["session_id"] = session_id

try:
while True:
round_start_time = time.time()
sampling_params["max_new_tokens"] = max_new_tokens - len(response_token_ids)

if sampling_params["max_new_tokens"] <= 0:
@@ -165,45 +274,24 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:

sample.status = Sample.Status.TRUNCATED
return postprocess_sample(sample, prompt_token_ids, response_token_ids, loss_masks, response,
max_new_tokens, state.tokenizer)
max_new_tokens, state.tokenizer, debug_log_dict)

# Prepare payload for sglang server
payload = {
"input_ids": prompt_token_ids + response_token_ids,
"sampling_params": sampling_params,
}

# Log payload to wandb for debugging
try:
import wandb

if wandb.run is not None:
# Count available tools (from tool_specs)
available_tools = len(tool_specs)
# Count tools used in the current response
tools_used = response.count("<tool_call>")

wandb.log(
{
"debug/total_length": len(prompt + response),
"debug/total_token_length": len(prompt_token_ids + response_token_ids),
"debug/response_length": len(response),
"debug/response_token_length": len(response_token_ids),
"debug/available_tools": available_tools,
"debug/tools_used": tools_used,
"debug/turn": turn,
}
)
except ImportError:
pass # wandb not available

sgl_generation_start_time = time.time()
output = await post(url, payload)
_log_duration_time(sgl_generation_start_time, debug_log_dict, "sgl_generation_times")

# Handle abort
if output["meta_info"]["finish_reason"]["type"] == "abort":
_log_duration_time(round_start_time, debug_log_dict, "round_total_times")
sample.status = Sample.Status.ABORTED
return postprocess_sample(sample, prompt_token_ids, response_token_ids, loss_masks, response,
max_new_tokens, state.tokenizer)
max_new_tokens, state.tokenizer, debug_log_dict)

cur_response_token_ids = output["output_ids"]
cur_response = state.tokenizer.decode(cur_response_token_ids, skip_special_tokens=False)
@@ -216,10 +304,12 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:

# Check length limit
if output["meta_info"]["finish_reason"]["type"] == "length":
_log_duration_time(round_start_time, debug_log_dict, "round_total_times")
break

next_obs, done = await execute_predictions(session_id, state, cur_response)
next_obs, done = await execute_predictions(session_id, state, cur_response, debug_log_dict)
if done:
_log_duration_time(round_start_time, debug_log_dict, "round_total_times")
break
if debug:
# Inspect the tool_response output after apply_chat_template
@@ -236,6 +326,9 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
loss_masks += [0] * len(obs_tokens_ids)
turn += 1

debug_log_dict["turn"] = turn
_log_duration_time(round_start_time, debug_log_dict, "round_total_times")

# Set status
match output["meta_info"]["finish_reason"]["type"]:
case "length":
@@ -246,13 +339,27 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
sample.status = Sample.Status.COMPLETED

sample = postprocess_sample(sample, prompt_token_ids, response_token_ids, loss_masks, response,
max_new_tokens, state.tokenizer)
max_new_tokens, state.tokenizer, debug_log_dict)
if debug:
print(f"sample: {sample}")
print(f"[session_id: {session_id}] sample: {sample}")
finally:
# close jupyter session
result = await tool_registry.jupyter_client.end_session(session_id)
if debug:
print(f"[session_id: {session_id}] End session result: {result}")

return sample

# Record the elapsed time since start_time into log_dict
def _log_duration_time(start_time: float, log_dict: dict, key: str) -> float:
end_time = time.time()
duration = end_time - start_time

if log_dict.get(key) is None:
pass
elif isinstance(log_dict.get(key), list):
log_dict[key].append(duration)
else:
log_dict[key] += duration

return end_time
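
`get_init_metrics` is imported from `.statistic_metrics`, which is not part of this diff. Judging from the keys the code above reads and appends to, a minimal compatible stub would look roughly like this (a sketch of the assumed shape, not the actual module):

```python
def get_init_metrics() -> dict:
    # Per-round timing lists appended to during generation; scalar fields
    # (turn, totals, ratios) are filled in later by postprocess_log_dict.
    return {
        "round_total_times": [],
        "sgl_generation_times": [],
        "tool_execution_times": [],
        "tool_wait_lock_times": [],
        "tool_wait_lock_and_execution_times": [],
    }
```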
10 changes: 5 additions & 5 deletions examples/tool_use/jupyter_tool.py
@@ -8,6 +8,7 @@
"""

import asyncio
import httpx
import os
import requests
import traceback
@@ -26,9 +27,8 @@

class JupyterToolClient:
"""Client for interacting with the Jupyter tool service"""
def __init__(self, server_url: str, http_timeout: int = 600):
# request timeout in seconds
self.http_timeout = http_timeout
def __init__(self, server_url: str, http_timeout: int = 300):
self.http_client = httpx.AsyncClient(timeout=httpx.Timeout(http_timeout))

self.default_server_url = "http://localhost:8000"
self.server_url = server_url or os.getenv("JUPYTER_TOOL_SERVER_URL", self.default_server_url)
@@ -46,7 +46,7 @@ async def execute_code(self, session_id: str, code: str) -> str:
payload = {"session_id": session_id, "code": code}

try:
response = requests.post(url, json=payload, headers=headers, timeout=self.http_timeout)
response = await self.http_client.post(url, json=payload or {}, headers=headers)
if response.status_code == 200:
return response.json().get("output", "")
else:
@@ -68,7 +68,7 @@ async def end_session(self, session_id: str) -> str:
}

try:
response = requests.post(url, headers=headers, timeout=self.http_timeout)
response = await self.http_client.post(url, headers=headers)
if response.status_code == 200:
return response.json().get("output", "")
else:
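
A side note on the `requests` → `httpx.AsyncClient` migration: the async client keeps a connection pool, and the diff does not show it being closed anywhere. If deterministic cleanup is wanted, one option is a context-managed client; a sketch with a hypothetical endpoint and payload:

```python
import asyncio

import httpx

async def main() -> None:
    # Using the client as an async context manager releases pooled
    # connections when the block exits. URL and payload are hypothetical.
    async with httpx.AsyncClient(timeout=httpx.Timeout(300)) as client:
        resp = await client.post(
            "http://localhost:8000/execute",
            json={"session_id": "demo", "code": "1 + 1"},
        )
        print(resp.status_code)

asyncio.run(main())
```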