
TreeQuest



A flexible answer tree search library featuring AB-MCTS, useful for (but not limited to) LLM inference-time scaling.

Quick Start

import random

import treequest as tq

# Each node is associated with a user-definable `state`.
State = str

# 1. Define a function to be used for node generation.
def generate(parent_state: State | None) -> tuple[State, float]:
    """Generates new states and scores based on the parent state."""
    if parent_state is None: # None represents the expansion from root.
        new_state = "Initial state"
    else:
        new_state = f"State after {parent_state}"

    score = random.random() # A score for the new state; it should be normalized to the [0, 1] range.
    return new_state, score

# 2. Instantiate the algorithm and a search tree object.
algo = tq.ABMCTSA()
search_tree = algo.init_tree()

# 3. Run the search with a generation budget (10 in this case).
for _ in range(10):
    search_tree = algo.step(search_tree, {'Action A': generate})

# 4. Extract the best score and state.
best_state, best_node_score = tq.top_k(search_tree, algo, k=1)[0]
print(f"Best state: {best_state}, Score: {best_node_score}")

Alternatively, you can use an ask–tell interface with batched AB-MCTS sampling steps:

import random
import treequest as tq

State = str

def generate(parent_state: State | None) -> tuple[State, float]:
    ...

generate_fns = {"Action A": generate}
actions = list(generate_fns.keys())

# We use batch_size=5 here
batch_size = 5

# Run the AB-MCTS sampling step with up to 5 process workers in parallel
algo = tq.ABMCTSM(max_process_workers=batch_size)
search_tree = algo.init_tree()

total_budget = 50
num_steps = total_budget // batch_size
for _ in range(num_steps):
    # ask_batch returns a list of `Trial` objects, each with action, parent_state, and trial_id attributes
    search_tree, trials = algo.ask_batch(search_tree, batch_size, actions)

    for trial in trials:
        result = generate_fns[trial.action](trial.parent_state)
        # Call tell method with trial_id to update search_tree
        search_tree = algo.tell(search_tree, trial.trial_id, result)

best_state, best_node_score = tq.top_k(search_tree, algo, k=1)[0]

For AB-MCTS-M in particular, each step call can be slow; if you encounter slow execution, prefer ask_batch over step. Note that a large batch_size can skew the search-tree shape (i.e., the tree may become too wide), so it is best to avoid overly large values; see PROFILING.md for example trees. We recommend batch_size <= 5 as a starting point.
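
Because tell is order-independent, the generation calls within a batch can run concurrently. Below is a minimal sketch using concurrent.futures.ThreadPoolExecutor, assuming the generate_fns and actions from the example above (and that your generate function is safe to call from multiple threads, as typical LLM API calls are):

import concurrent.futures

import treequest as tq

batch_size = 5
algo = tq.ABMCTSM(max_process_workers=batch_size)
search_tree = algo.init_tree()

with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
    for _ in range(10):
        search_tree, trials = algo.ask_batch(search_tree, batch_size, actions)
        # Launch every generation call in the batch concurrently.
        future_to_trial_id = {
            executor.submit(generate_fns[trial.action], trial.parent_state): trial.trial_id
            for trial in trials
        }
        # Reflect each (state, score) result as soon as it completes;
        # tell is order-independent, so completion order does not matter.
        for future in concurrent.futures.as_completed(future_to_trial_id):
            search_tree = algo.tell(search_tree, future_to_trial_id[future], future.result())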

Features

  • Easy-to-use API with customizable node generation and node scoring logic.
  • AB-MCTS-A and AB-MCTS-M, as well as Multi-LLM AB-MCTS support (see our paper for algorithm details).
  • Checkpointing and resuming searches (see the sketch below).
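
Because the search tree is an ordinary Python object returned by each step call, one simple way to checkpoint is to serialize it. A minimal sketch with pickle, assuming a generate function as in the Quick Start and that the tree and node states are picklable (states holding open clients or handles would need custom handling):

import pickle

import treequest as tq

algo = tq.ABMCTSA()
search_tree = algo.init_tree()

for i in range(100):
    search_tree = algo.step(search_tree, {"Action A": generate})
    # Persist the tree every 10 steps so the search can be resumed later.
    if (i + 1) % 10 == 0:
        with open("checkpoint.pkl", "wb") as f:
            pickle.dump(search_tree, f)

# Resuming: load the checkpoint and keep stepping with the same algorithm.
with open("checkpoint.pkl", "rb") as f:
    search_tree = pickle.load(f)
search_tree = algo.step(search_tree, {"Action A": generate})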

Installation

uv

First, install uv. Then you can install TreeQuest with the following command:

uv add "treequest[abmcts-m]"

pip

Alternatively, you can use pip to install TreeQuest:

pip install "treequest[abmcts-m]"

Usage

Using an LLM as a Node Generator

You can use any object as a node state. You only need to define a generation function that takes the parent state as an argument and returns a (state, score) tuple:

import dataclasses

import treequest as tq

@dataclasses.dataclass
class State:
    llm_answer: str
    score: float

def generate(parent_state: State | None) -> tuple[State, float]:
    """Generate a new node by calling an LLM."""
    if parent_state is None:
        state = initial_generation()
    else:
        state = refine_answer(parent_state.llm_answer, parent_state.score)

    return state, state.score
    
def initial_generation() -> State:
    """
    Call LLM API to generate an initial answer.
    """
    ...

def refine_answer(llm_answer: str, score: float) -> State:
    """
    Call LLM API to refine an answer.
    """
    ...


algo = tq.ABMCTSM()
search_tree = algo.init_tree()
for i in range(20):
    search_tree = algo.step(search_tree, {'Action Label': generate})
    # Logging best node during the search.
    if (i + 1) % 5 == 0:
        best_interim_state, _ = tq.top_k(search_tree, algo, k=1)[0]
        print(f"Iteration {i+1}: Best state so far = {best_interim_state}")

best_state, _ = tq.top_k(search_tree, algo, k=1)[0]
print(f"Best Answer: {best_state.llm_answer}, Best Score: {best_state.score}")

Using Multiple LLMs (and Beyond)

TreeQuest supports multiple action types. For example, you can provide multiple generation functions backed by different LLMs to represent different action types:

from functools import partial

import treequest as tq

def generate(llm_name: str, parent_state=None):
    """
    Call LLM API using litellm, vllm, etc., to generate a new node
    """
    ...
    return new_state, new_score

llm_names = ["o4-mini", "gemini-2.5-pro"]
# Create a dict of actions backed by different LLMs; bind llm_name positionally
# so each function can be called with parent_state as its sole positional argument.
generate_fns = {llm_name: partial(generate, llm_name) for llm_name in llm_names}

algo = tq.StandardMCTS()
search_tree = algo.init_tree()
for _ in range(20):
    search_tree = algo.step(search_tree, generate_fns)

The variation is not limited to LLM types; generate_fns can also mix different prompts, actions, scoring logic, and so on, as in the sketch below.
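
For example, a single LLM can back several actions that differ only in their prompt. A minimal sketch, assuming a hypothetical generate_with_prompt function whose body calls the LLM as in the previous examples:

from functools import partial

import treequest as tq

# Hypothetical prompt templates defining two refinement styles.
PROMPTS = {
    "rewrite": "Rewrite the answer from scratch: {answer}",
    "critique": "Critique the answer, then fix it: {answer}",
}

def generate_with_prompt(prompt_template: str, parent_state=None):
    """Call the LLM with the given prompt template (body elided)."""
    ...
    return new_state, new_score

# One action per prompt style.
generate_fns = {
    name: partial(generate_with_prompt, template)
    for name, template in PROMPTS.items()
}

algo = tq.StandardMCTS()
search_tree = algo.init_tree()
for _ in range(20):
    search_tree = algo.step(search_tree, generate_fns)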

Batch Semantics and Concurrency

  • Algorithms are stateless objects; the evolving tree/search state is returned from init_tree, step, ask, and tell.
  • ask_batch(state, batch_size, actions) returns exactly batch_size Trial objects to expand next.
    • Non-queue algorithms (e.g., ABMCTSM, ABMCTSA, MultiArmedBanditUCB) return exactly batch_size Trials.
    • Queue-based algorithms (e.g., StandardMCTS, BestFirstSearchAlgo, TreeOfThoughtsBFS) precompute a set of parent/action pairs and duplicate them if needed to fill batch_size.
  • tell(state, trial_id, (new_state, score)) reflects the result for the corresponding Trial.
    • Order-independent: you can call tell in any order; reflection is tied to trial_id (see the sketch after this list).
    • Idempotent: calling tell twice with the same trial_id does not add extra nodes.
    • For queue-based algorithms, Trials told beyond the possible number of children of a parent node (e.g., (# actions) * samples_per_action for StandardMCTS) become INVALID and are not reflected.
  • Scores are expected to be normalized to the [0, 1] range.
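
A small sketch illustrating the order-independence and idempotence of tell, assuming the generate_fns and actions from the ask–tell example above:

import treequest as tq

algo = tq.ABMCTSA()
search_tree = algo.init_tree()

search_tree, trials = algo.ask_batch(search_tree, 3, actions)

# Results can be told in any order; here we process the batch in reverse.
for trial in reversed(trials):
    result = generate_fns[trial.action](trial.parent_state)
    search_tree = algo.tell(search_tree, trial.trial_id, result)

# Telling the same trial_id again is a no-op; no extra node is added.
# (After the reversed loop, `result` belongs to trials[0].)
search_tree = algo.tell(search_tree, trials[0].trial_id, result)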

Algorithms

AB-MCTS-A: AB-MCTS with Node Aggregation

AB-MCTS-A uses node aggregation for adaptive branching:

import treequest as tq

# Instantiate the ABMCTS-A algorithm.
ab_mcts_a = tq.ABMCTSA()

search_tree = ab_mcts_a.init_tree()
for _ in range(50):
    search_tree = ab_mcts_a.step(search_tree, generate_fns)

AB-MCTS-M: AB-MCTS with Mixed Models

AB-MCTS-M leverages PyMC's mixed-modeling capabilities:

import treequest as tq

# Instantiate the ABMCTS-M algorithm.
ab_mcts_m = tq.ABMCTSM()

search_tree = ab_mcts_m.init_tree()
for _ in range(30):
    search_tree = ab_mcts_m.step(search_tree, generate_fns)

NOTE: To run AB-MCTS-M, you need to install extra dependencies with the treequest[abmcts-m] option.

Requirements

  • Python 3.11+

Contributing

Contributions are welcome! Please see CONTRIBUTING.md for development tips.

Citation

@article{inoue2025wider,
  title={Wider or Deeper? Scaling LLM Inference-Time Compute with Adaptive Branching Tree Search},
  author={Inoue, Yuichi and Misaki, Kou and Imajuku, Yuki and Kuroki, So and Nakamura, Taishi and Akiba, Takuya},
  journal={arXiv preprint arXiv:2503.04412},
  year={2025}
}

License

Apache 2.0
