MilesCranmer/SymbolicRegression.jl

SymbolicRegression.jl searches for symbolic expressions which optimize a particular objective.

Check out PySR for a Python frontend. Cite this software

Mark Kittisopikul: 💻 💡 🚇 📦 📣 👀 🔧 ⚠️
T Coxon: 🐛 💻 🔌 💡 🚇 🚧 👀 🔧 ⚠️ 📓
Dhananjay Ashok: 💻 🌍 💡 🚧 ⚠️
Johan Blåbäck: 🐛 💻 💡 🚧 📣 👀 ⚠️ 📓
JuliusMartensen: 🐛 💻 📖 🔌 💡 🚇 🚧 📦 📣 👀 🔧 📓
ngam: 💻 🚇 📦 👀 🔧 ⚠️
Kaze Wong: 🐛 💻 💡 🚇 🚧 📣 👀 🔬 📓
Christopher Rackauckas: 🐛 💻 🔌 💡 🚇 📣 👀 🔬 🔧 ⚠️ 📓
Patrick Kidger: 🐛 💻 📖 🔌 💡 🚧 📣 👀 🔬 🔧 ⚠️ 📓
Okon Samuel: 🐛 💻 📖 🚧 💡 🚇 👀 ⚠️ 📓
William Booth-Clibborn: 💻 🌍 📖 📓 🚧 👀 🔧 ⚠️
Pablo Lemos: 🐛 💡 📣 👀 🔬 📓
Jerry Ling: 🐛 💻 📖 🌍 💡 📣 👀 📓
Charles Fox: 🐛 💻 💡 🚧 📣 👀 🔬 📓
Johann Brehmer: 💻 📖 💡 📣 👀 🔬 ⚠️ 📓
Marius Millea: 💻 💡 📣 👀 📓
Coba: 🐛 💻 💡 👀 📓
Pietro Monticone: 🐛 📖 💡
Mateusz Kubica: 📖 💡
Jay Wadekar: 🐛 💡 📣 🔬
Anthony Blaom, PhD: 🚇 💡 👀
Jgmedina95: 🐛 💡 👀
Michael Abbott: 💻 💡 👀 🔧
Oscar Smith: 💻 💡
Eric Hanson: 💡 📣 📓
Henrique Becker: 💻 💡 👀
qwertyjl: 🐛 📖 💡 📓
Rik Huijzer: 💡 🚇
Hongyu Wang: 💡 📣 🔬
Saurav Maheshkar: 🔧

Quickstart

Install in Julia with:

using Pkg
Pkg.add("SymbolicRegression")

MLJ Interface

The easiest way to use SymbolicRegression.jl is with MLJ. Let's see an example:

import SymbolicRegression: SRRegressor
import MLJ: machine, fit!, predict, report

# Dataset with two named features:
X = (a = rand(500), b = rand(500))

# and one target:
y = @. 2 * cos(X.a * 23.5) - X.b ^ 2

# with some noise:
y = y .+ randn(500) .* 1e-3

model = SRRegressor(
    niterations=50,
    binary_operators=[+, -, *],
    unary_operators=[cos],
)

Now, let's create and train this model on our data:

mach = machine(model, X, y)

fit!(mach)

You will notice that expressions are printed using the column names of our table. If, instead of a table-like object, a simple array is passed (e.g., X=randn(100, 2)), x1, ..., xn will be used for variable names.

Let's look at the expressions discovered:

report(mach)

Finally, we can make predictions with the expressions on new data:

predict(mach, X)

This will make predictions using the expression selected by model.selection_method, which by default is a mix of accuracy and complexity.

You can override this selection and select an equation from the Pareto front manually with:

predict(mach, (data=X, idx=2))

where here we choose to evaluate the second equation.
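
To decide which index to pass, it helps to list the candidates first. The following is a minimal sketch, not the canonical workflow: it assumes the report exposes `equation_strings`, `complexities`, `losses`, and `best_idx` fields (names as I understand the MLJ interface; check the API page for your installed version), and it uses a deliberately small `niterations` so that it finishes quickly.

```julia
import SymbolicRegression: SRRegressor
import MLJ: machine, fit!, predict, report

X = (a = rand(100), b = rand(100))
y = @. 2 * cos(X.a * 23.5) - X.b ^ 2

# Small search budget: enough for a demonstration, not for a good fit
model = SRRegressor(niterations=5, binary_operators=[+, -, *], unary_operators=[cos])
mach = machine(model, X, y)
fit!(mach)

r = report(mach)

# Print index, complexity, loss, and equation for every candidate
# (field names are assumptions, see the note above)
for (i, eq) in enumerate(r.equation_strings)
    println(i, "\t", r.complexities[i], "\t", r.losses[i], "\t", eq)
end

# Index picked by the default selection method
println("default choice: ", r.best_idx)

# Override the selection: predict with the first listed equation
y_first = predict(mach, (data=X, idx=1))
```

Comparing `r.best_idx` with the printed list shows how the default selection trades accuracy against complexity on your data.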

For fitting multiple outputs, one can use MultitargetSRRegressor (and pass an array of indices to idx in predict for selecting specific equations). For a full list of options available to each regressor, see the API page.
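
As a rough illustration of the multi-output case, the sketch below fits two targets at once. It assumes that `MultitargetSRRegressor` accepts the same keyword arguments as `SRRegressor` and that the targets can be passed as a table with one column per output; both are assumptions to verify against the API page.

```julia
import SymbolicRegression: MultitargetSRRegressor
import MLJ: machine, fit!, predict

X = (a = rand(100), b = rand(100))

# Two targets, passed as a table with one column per output (assumed input format)
y = (y1 = 2 .* X.a .+ X.b, y2 = X.a .* X.b)

# Small search budget so the sketch finishes quickly
model = MultitargetSRRegressor(niterations=5, binary_operators=[+, *])
mach = machine(model, X, y)
fit!(mach)

# One index per target: here the first equation found for y1 and for y2
y_pred = predict(mach, (data=X, idx=[1, 1]))
```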

Low-Level Interface

The heart of SymbolicRegression.jl is the equation_search function. This takes a 2D array and attempts to model a 1D array using analytic functional forms. Note: unlike the MLJ interface, this assumes column-major input of shape [features, rows].

import SymbolicRegression: Options, equation_search

X = randn(2, 100)
y = 2 * cos.(X[2, :]) + X[1, :] .^ 2 .- 2

options = Options(
    binary_operators=[+, *, /, -],
    unary_operators=[cos, exp],
    populations=20
)

hall_of_fame = equation_search(
    X, y, niterations=40, options=options,
    parallelism=:multithreading
)
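
If your data starts out in the table form used in the MLJ example, it has to be rearranged into the [features, rows] layout first. A minimal sketch using only Base Julia (the variable name `table` is just for illustration):

```julia
# Table-like input as in the MLJ example: one vector per named feature
table = (a = rand(500), b = rand(500))

# hcat gives a 500×2 matrix (samples × features);
# permutedims flips it so that each row of X holds one feature
X = permutedims(hcat(table.a, table.b))

size(X)  # (2, 500): 2 features, 500 samples
```

With this layout, `X[1, :]` is feature `a` and `X[2, :]` is feature `b`.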

You can view the resultant equations in the dominating Pareto front (best expression seen at each complexity) with:

import SymbolicRegression: calculate_pareto_frontier

dominating = calculate_pareto_frontier(hall_of_fame)

This is a vector of PopMember objects, each of which contains an expression along with its score. We can get the expressions with:

trees = [member.tree for member in dominating]

Each of these equations is a Node{T} type for some constant type T (like Float32).
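
To make the notion of a dominating Pareto front concrete, here is a toy sketch on plain (complexity, loss) pairs. It is not the library's implementation of `calculate_pareto_frontier`; the function name `dominating_front` and the data are hypothetical. An entry is kept only if no simpler entry already achieves an equal or lower loss.

```julia
# Hypothetical (complexity, loss) pairs, e.g. one per complexity level
entries = [(5, 1.2), (1, 2.5), (3, 1.0), (7, 0.1)]

function dominating_front(entries)
    front = Tuple{Int,Float64}[]
    for (complexity, loss) in sort(entries)  # tuples sort by complexity first
        # Keep an entry only if it beats every simpler entry kept so far.
        # Kept losses are strictly decreasing, so checking the last one suffices.
        if isempty(front) || loss < front[end][2]
            push!(front, (complexity, loss))
        end
    end
    return front
end

front = dominating_front(entries)
# (5, 1.2) is dropped because the simpler (3, 1.0) already has a lower loss
```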

You can evaluate a given tree with:

import SymbolicRegression: eval_tree_array

tree = trees[end]
output, did_succeed = eval_tree_array(tree, X, options)

The output array will contain the result of the tree at each of the 100 rows. The did_succeed flag indicates whether the evaluation was successful, or whether it encountered any NaNs or Infs during calculation (such as, e.g., sqrt(-1)).
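
The idea behind such a flag can be sketched in plain Julia. This is only an illustration under stated assumptions, not how the package evaluates trees: `safe_sqrt` and `evaluate_checked` are hypothetical helpers. Note that Base's `sqrt(-1.0)` throws a `DomainError`, so the sketch uses a stand-in that returns NaN instead.

```julia
# Hypothetical stand-in for a "safe" operator: returns NaN rather than throwing
safe_sqrt(x) = x < zero(x) ? convert(typeof(x), NaN) : sqrt(x)

# Evaluate f on every input and report whether all results are finite
function evaluate_checked(f, xs)
    output = map(f, xs)
    return output, all(isfinite, output)
end

out_ok, ok = evaluate_checked(safe_sqrt, [0.0, 1.0, 4.0])   # ok is true
out_bad, bad = evaluate_checked(safe_sqrt, [1.0, -1.0])     # bad is false (NaN present)
```

In practice this means the output should only be trusted when the flag is true.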

Constructing expressions

Expressions are represented as the Node type, which is developed in the DynamicExpressions.jl package.

You can manipulate and construct expressions directly. For example:

import SymbolicRegression: Options, Node, eval_tree_array

options = Options(;
    binary_operators=[+, -, *, ^, /], unary_operators=[cos, exp, sin]
)
x1, x2, x3 = [Node(; feature=i) for i=1:3]
tree = cos(x1 - 3.2 * x2) - x1^3.2

This tree has Float64 constants, so the type of the entire tree will be promoted to Node{Float64}.

We can convert all constants (recursively) to Float32:

float32_tree = convert(Node{Float32}, tree)
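
The promotion follows Julia's ordinary numeric rules, which can be checked in isolation. A small sketch using only Base Julia (variable names are illustrative):

```julia
a = 3.2      # a plain decimal literal is a Float64
b = 3.2f0    # the f0 suffix makes a Float32 literal

# Mixing the two widens to the larger type
T = promote_type(typeof(a), typeof(b))  # Float64

# Explicit conversion narrows a Float64 value to Float32
c = convert(Float32, a)
```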

We can then evaluate this tree on a dataset:

X = rand(Float32, 3, 100)
output, did_succeed = eval_tree_array(float32_tree, X, options)
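
As a sanity check, the same formula can be written directly with broadcasting. This sketch uses only Base Julia and assumes the [features, rows] layout described earlier, so feature i is row i of X. On non-negative inputs like these, the result should agree with the tree's output up to floating-point rounding.

```julia
X = rand(Float32, 3, 100)  # 3 features (rows) by 100 samples (columns)

# Feature i is row i of X; x3 is not used by this particular expression
x1 = X[1, :]
x2 = X[2, :]

# Same formula as the tree, with Float32 constants so the result stays Float32
expected = @. cos(x1 - 3.2f0 * x2) - x1^3.2f0
```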

Exporting to SymbolicUtils.jl

We can view the equations in the dominating Pareto frontier with:

dominating = calculate_pareto_frontier(hall_of_fame)

We can convert the best equation to SymbolicUtils.jl with the following function:

import SymbolicRegression: node_to_symbolic
import SymbolicUtils: simplify  # simplify is provided by SymbolicUtils.jl

eqn = node_to_symbolic(dominating[end].tree, options)
println(simplify(eqn*5 + 3))

We can also print out the full Pareto frontier like so:

import SymbolicRegression: compute_complexity, string_tree

println("Complexity\tMSE\tEquation")

for member in dominating
    complexity = compute_complexity(member, options)
    loss = member.loss
    equation = string_tree(member.tree, options)  # avoid shadowing Base.string

    println("$(complexity)\t$(loss)\t$(equation)")
end

Code structure

SymbolicRegression.jl is organized roughly as follows. Rounded rectangles indicate objects, and rectangles indicate functions.

(if you can't see this diagram being rendered, try pasting it into mermaid-js.github.io/mermaid-live-editor)

flowchart TB
    op([Options])
    d([Dataset])
    op --> ES
    d --> ES
    subgraph ES[equation_search]
        direction TB
        IP[sr_spawner]
        IP --> p1
        IP --> p2
        subgraph p1[Thread 1]
            direction LR
            pop1([Population])
            pop1 --> src[s_r_cycle]
            src --> opt[optimize_and_simplify_population]
            opt --> pop1
        end
        subgraph p2[Thread 2]
            direction LR
            pop2([Population])
            pop2 --> src2[s_r_cycle]
            src2 --> opt2[optimize_and_simplify_population]
            opt2 --> pop2
        end
        pop1 --> hof
        pop2 --> hof
        hof([HallOfFame])
        hof --> migration
        pop1 <-.-> migration
        pop2 <-.-> migration
        migration[migrate!]
    end
    ES --> output([HallOfFame])

The HallOfFame objects store the expressions with the lowest loss seen at each complexity.
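
The bookkeeping this implies can be sketched with a toy structure that keeps one slot per complexity. `ToyHallOfFame` and `update!` are hypothetical names for illustration and are not the package's HallOfFame type; equations are stored as strings to keep the sketch self-contained.

```julia
# Hypothetical, simplified stand-in for a hall of fame: one slot per complexity
struct ToyHallOfFame
    equations::Vector{String}  # best equation seen at each complexity
    losses::Vector{Float64}    # loss of that equation
    exists::Vector{Bool}       # whether the slot has been filled yet
end

ToyHallOfFame(maxsize::Int) =
    ToyHallOfFame(fill("", maxsize), fill(Inf, maxsize), fill(false, maxsize))

function update!(hof::ToyHallOfFame, equation::String, complexity::Int, loss::Float64)
    in_range = 1 <= complexity <= length(hof.exists)
    # Store the newcomer only if the slot is empty or it has a lower loss
    if in_range && (!hof.exists[complexity] || loss < hof.losses[complexity])
        hof.equations[complexity] = equation
        hof.losses[complexity] = loss
        hof.exists[complexity] = true
    end
    return hof
end

hof = ToyHallOfFame(5)
update!(hof, "x1", 1, 2.0)
update!(hof, "x1 * x2", 3, 0.8)
update!(hof, "x1 + x2", 3, 0.5)  # replaces the previous complexity-3 entry
update!(hof, "x2 - x1", 3, 0.9)  # ignored: worse than the stored entry
```

Slots that were never filled (here complexities 2, 4, and 5) stay marked as not existing, which is why the dominating front is computed as a separate step.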

The dependency structure of the code itself is as follows:

stateDiagram-v2
    AdaptiveParsimony --> Mutate
    AdaptiveParsimony --> Population
    AdaptiveParsimony --> RegularizedEvolution
    AdaptiveParsimony --> SingleIteration
    AdaptiveParsimony --> SymbolicRegression
    CheckConstraints --> Mutate
    CheckConstraints --> SymbolicRegression
    Complexity --> CheckConstraints
    Complexity --> HallOfFame
    Complexity --> LossFunctions
    Complexity --> Mutate
    Complexity --> Population
    Complexity --> SearchUtils
    Complexity --> SingleIteration
    Complexity --> SymbolicRegression
    ConstantOptimization --> Mutate
    ConstantOptimization --> SingleIteration
    Core --> AdaptiveParsimony
    Core --> CheckConstraints
    Core --> Complexity
    Core --> ConstantOptimization
    Core --> HallOfFame
    Core --> InterfaceDynamicExpressions
    Core --> LossFunctions
    Core --> Migration
    Core --> Mutate
    Core --> MutationFunctions
    Core --> PopMember
    Core --> Population
    Core --> Recorder
    Core --> RegularizedEvolution
    Core --> SearchUtils
    Core --> SingleIteration
    Core --> SymbolicRegression
    Dataset --> Core
    HallOfFame --> SearchUtils
    HallOfFame --> SingleIteration
    HallOfFame --> SymbolicRegression
    InterfaceDynamicExpressions --> LossFunctions
    InterfaceDynamicExpressions --> SymbolicRegression
    LossFunctions --> ConstantOptimization
    LossFunctions --> HallOfFame
    LossFunctions --> Mutate
    LossFunctions --> PopMember
    LossFunctions --> Population
    LossFunctions --> SymbolicRegression
    Migration --> SymbolicRegression
    Mutate --> RegularizedEvolution
    MutationFunctions --> Mutate
    MutationFunctions --> Population
    MutationFunctions --> SymbolicRegression
    Operators --> Core
    Operators --> Options
    Options --> Core
    OptionsStruct --> Core
    OptionsStruct --> Options
    PopMember --> ConstantOptimization
    PopMember --> HallOfFame
    PopMember --> Migration
    PopMember --> Mutate
    PopMember --> Population
    PopMember --> RegularizedEvolution
    PopMember --> SingleIteration
    PopMember --> SymbolicRegression
    Population --> Migration
    Population --> RegularizedEvolution
    Population --> SearchUtils
    Population --> SingleIteration
    Population --> SymbolicRegression
    ProgramConstants --> Core
    ProgramConstants --> Dataset
    ProgressBars --> SearchUtils
    ProgressBars --> SymbolicRegression
    Recorder --> Mutate
    Recorder --> RegularizedEvolution
    Recorder --> SingleIteration
    Recorder --> SymbolicRegression
    RegularizedEvolution --> SingleIteration
    SearchUtils --> SymbolicRegression
    SingleIteration --> SymbolicRegression
    Utils --> CheckConstraints
    Utils --> ConstantOptimization
    Utils --> Options
    Utils --> PopMember
    Utils --> SingleIteration
    Utils --> SymbolicRegression

Bash command to generate dependency structure from src directory (requires vim-stream):

echo 'stateDiagram-v2'
IFS=$'\n'
for f in *.jl; do
    for line in $(cat $f | grep -e 'import \.\.' -e 'import \.'); do
        echo $(echo $line | vims -s 'dwf:d$' -t '%s/^\.*//g' '%s/Module//g') $(basename "$f" .jl);
    done;
done | vims -l 'f a--> ' | sort

Search options

See https://astroautomata.com/SymbolicRegression.jl/stable/api/#Options