🤗ToolGen Model • 📄Paper (arxiv)
25/02/2025: Updated the training scripts. Refer to training/README.md for more details.
14/12/2024: Updated Qwen2.5-based ToolGen
ToolGen is a framework that integrates tool knowledge directly into LLMs by representing tools as unique tokens, enabling seamless tool invocation and language generation.🔧🦙 With 47,000 tool tokens, ToolGen shows superior performance in both tool retrieval and task completion.
The following code snippet shows how to run ToolGen locally. First, get your ToolBench key from the ToolBench repo. Then deploy StableToolBench following the instructions in its repo.
import json
from OpenAgent import ToolGen
from OpenAgent import RapidAPIWrapper
# Initialize rapid api tools
with open("keys.json", 'r') as f:
    keys = json.load(f)
toolbench_key = keys['TOOLBENCH_KEY']
rapidapi_wrapper = RapidAPIWrapper(
    toolbench_key=toolbench_key,
    rapidapi_key="",
)
toolgen = ToolGen(
    "reasonwang/ToolGen-Llama-3-8B", # reasonwang/ToolGen-Qwen2.5-3B
    template="llama-3", # qwen-7b-chat
    indexing="Atomic",
    tools=rapidapi_wrapper,
)
messages = [
    {"role": "system", "content": ""},
    {"role": "user", "content": "I'm a football fan and I'm curious about the different team names used in different leagues and countries. Can you provide me with an extensive list of football team names and their short names? It would be great if I could access more than 7000 team names. Additionally, I would like to see the first 25 team names and their short names using the basic plan."}
]
toolgen.restart()
toolgen.start(
    single_chain_max_step=16,
    start_messages=messages
)
Download and decompress data.tar.gz. Other datasets are at 🤗ToolGen-Datasets.
The first step is to map tools into tokens. We have extracted all the tools in ToolBench and converted them into tokens, as shown in virtual_tokens.txt. The following code adds the tokens into the vocabulary and expands model embeddings.
import torch
import transformers
from unidecode import unidecode

# Read the tool tokens, one per line, and normalize them to ASCII
with open('data/virtual_tokens.txt', 'r') as f:
    virtual_tokens = [unidecode(vt.strip()) for vt in f]
model_name_or_path = "meta-llama/Meta-Llama-3-8B"
# Load tokenizer and add tokens into vocabulary
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.add_tokens(new_tokens=virtual_tokens, special_tokens=False)
# Load model and expand embeddings
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16
)
model.resize_token_embeddings(len(tokenizer))
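# Split each virtual token "<<Tool&&API>>" into its tool and API name parts
combined_tokens = []
for vt in virtual_tokens:
    combined_token = vt[2:-2].split("&&")
    combined_tokens.append(combined_token)

# Initialize each new token's embedding as the mean of its parts' embeddings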
for combined_token, virtual_token in zip(combined_tokens, virtual_tokens):
    combined_token_ids = tokenizer(" ".join(combined_token), add_special_tokens=False).input_ids
    virtual_token_id = tokenizer(virtual_token, add_special_tokens=False).input_ids
    assert len(virtual_token_id) == 1
    combined_token_embeddings = model.model.embed_tokens(torch.tensor(combined_token_ids).to(model.device))
    embedding = torch.mean(combined_token_embeddings, dim=0)
    model.model.embed_tokens.weight.data[virtual_token_id[0]] = embedding
After tool virtualization, ToolGen is finetuned in three stages. The first stage is tool memorization, which trains the model to memorize all tool tokens. The data for this stage is at 🤗ToolGen-Memorization. We have converted it into a ShareGPT-like format for easy integration with existing training frameworks such as FastChat and LLaMA-Factory. Note that we train the first stage for 8 epochs. A sample is shown below:
{
    "conversations": [
        {
            "role": "user",
            "content": "Tool Name: QRCheck. Tool Description: Check the quality of any QRCode Api Name: quality_v1_quality_post Api Description: None.",
            "loss": false
        },
        {
            "role": "assistant",
            "content": "<<QRCheck&&quality_v1_quality_post>>",
            "loss": true
        }
    ]
}
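For tools outside the released dataset, a sample in this format could be assembled from tool metadata. The sketch below is illustrative only: the helper name `make_memorization_sample` and the prompt template are assumptions inferred from the single sample above, not code from this repository, and the reading of `loss` as "this turn contributes to the training loss" is likewise an assumption.

```python
import json


def make_memorization_sample(tool_name, tool_description, api_name, api_description=None):
    """Build one ShareGPT-like memorization sample (hypothetical helper).

    The user-turn wording is inferred from the sample shown above and may
    differ from the template used to build the released dataset.
    """
    user_content = (
        f"Tool Name: {tool_name}. "
        f"Tool Description: {tool_description} "
        f"Api Name: {api_name} "
        f"Api Description: {api_description}."
    )
    # The target is the tool's virtual token, written as <<Tool&&API>>
    virtual_token = f"<<{tool_name}&&{api_name}>>"
    return {
        "conversations": [
            # Assumed meaning: the prompt turn is excluded from the loss
            {"role": "user", "content": user_content, "loss": False},
            # Assumed meaning: the tool token turn is what the model learns
            {"role": "assistant", "content": virtual_token, "loss": True},
        ]
    }


sample = make_memorization_sample(
    "QRCheck", "Check the quality of any QRCode", "quality_v1_quality_post"
)
print(json.dumps(sample, indent=4))
```

Called with the QRCheck metadata, this reproduces the sample above. The released virtual tokens are also normalized with `unidecode`, which this sketch does not do.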
The second stage trains ToolGen's tool retrieval capability. The data is at 🤗ToolGen-Retrieval. We train this stage for 1 epoch. After the second stage, we obtain ToolGen-Retriever. A sample is shown below:
{
    "conversations": [
        {
            "role": "user",
            "content": "My friends and I are organizing a hackathon on 'web development' and 'mobile app development'. We need some inspiration and guidance. Can you fetch the top stories on these topics from Medium.com?",
            "loss": false
        },
        {
            "role": "assistant",
            "content": "<<Medium&&/search/topics>>",
            "loss": true
        }
    ]
}
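A retriever output in this format can be split back into a tool name and an API name. The sketch below mirrors the `vt[2:-2].split("&&")` step from the tool virtualization code. The helper name `parse_tool_token` is hypothetical, and it assumes the tool name itself never contains `&&`.

```python
def parse_tool_token(token):
    """Split a virtual token like "<<Tool&&API>>" into (tool, api).

    Hypothetical helper; raises ValueError on malformed input.
    """
    token = token.strip()
    if not (token.startswith("<<") and token.endswith(">>")):
        raise ValueError(f"not a tool token: {token!r}")
    # Drop the "<<" and ">>" delimiters, then split on the first "&&"
    tool, sep, api = token[2:-2].partition("&&")
    if not sep:
        raise ValueError(f"missing '&&' separator: {token!r}")
    return tool, api


print(parse_tool_token("<<Medium&&/search/topics>>"))
```

Rejecting malformed strings explicitly makes it easier to spot cases where the model generated ordinary text instead of a tool token.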
Finally, we train ToolGen on agent trajectories to give it task completion capability. The data is at 🤗ToolGen-Agent.
The following command shows how to evaluate retrieval performance. Other tool retrieval evaluation scripts can be found in scripts/retrieval.
python -m evaluation.retrieval.eval_toolgen \
    --model_name_or_path "reasonwang/ToolGen-Llama-3-8B-Tool-Retriever" \
    --indexing "Atomic" \
    --stage "G1" \
    --split "instruction" \
    --result_path data/results/retrieval/ \
    --constrain True
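The `--constrain` flag suggests that generation is restricted to tool tokens, although the script's actual mechanism is not described here. As a rough illustration of one way such a restriction can work with atomic indexing, the sketch below builds a callback that allows only tool tokens as the first generated token and only EOS afterwards. The function name and the token IDs are made up for the example.

```python
def make_allowed_tokens_fn(tool_token_ids, eos_token_id):
    """Return a callback that limits generation to tool tokens.

    Sketch for atomic indexing, where each tool is a single token:
    the first generated token must be a tool token, after which only
    EOS is allowed. All names and IDs here are illustrative.
    """
    tool_ids = sorted(set(tool_token_ids))
    tool_id_set = set(tool_ids)

    def allowed_tokens(batch_id, input_ids):
        # input_ids holds the prompt plus anything generated so far
        last_id = int(input_ids[-1]) if len(input_ids) > 0 else None
        if last_id in tool_id_set:
            return [eos_token_id]  # a tool was emitted, so stop
        return tool_ids

    return allowed_tokens


# Made-up IDs standing in for two tool tokens and the EOS token
fn = make_allowed_tokens_fn(tool_token_ids=[128256, 128257], eos_token_id=128001)
print(fn(0, [1, 2, 3]))          # prompt only: any tool token is allowed
print(fn(0, [1, 2, 3, 128257]))  # tool emitted: only EOS is allowed
```

In Hugging Face Transformers, a callback with this signature can be passed to `model.generate` through the `prefix_allowed_tokens_fn` argument.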
For end-to-end evaluation, first get a ToolBench key and run StableToolBench.
Then, perform inference on queries to generate trajectories. Scripts can be found in scripts/inference.
Next, use scripts/convert_answer/run_convert_answer.sh to convert the trajectory format. Then run scripts/pass_rate/run_pass_rate.sh for pass rate evaluation.
Run scripts/preference/run_preference.sh for win rate evaluation.
If our work is helpful, please kindly cite as:
@misc{wang2024toolgenunifiedtoolretrieval,
      title={ToolGen: Unified Tool Retrieval and Calling via Generation}, 
      author={Renxi Wang and Xudong Han and Lei Ji and Shu Wang and Timothy Baldwin and Haonan Li},
      year={2024},
      eprint={2410.03439},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2410.03439}, 
}
