Initial working version of multiclass
siboehm committed Aug 30, 2021
1 parent 52c15c2 commit e644556
Showing 12 changed files with 2,582 additions and 72 deletions.
2 changes: 0 additions & 2 deletions README.md
@@ -25,8 +25,6 @@ llvm_model.compile()
- Drop-in replacement: The interface of `lleaves.Model` is a subset of `LightGBM.Booster`.
- Dependencies: `llvmlite` and `numpy`. LLVM comes statically linked.

Some LightGBM features are not yet implemented: multiclass prediction, linear models.

## Installation
`conda install -c conda-forge lleaves` or `pip install lleaves` (Linux and MacOS only).
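
For context, the drop-in interface referenced above is used roughly like this (a minimal sketch; the keyword argument `model_file` and the `num_feature()`/`predict()` calls follow the project's public API as described in the README and are not part of this diff):

    import lleaves
    import numpy as np

    # Compile the LightGBM model.txt once, ahead of time
    llvm_model = lleaves.Model(model_file="model.txt")
    llvm_model.compile()

    # Predict like a LightGBM Booster; for a multiclass model the result
    # contains one probability per class and row
    X = np.random.rand(10, llvm_model.num_feature())
    probabilities = llvm_model.predict(X)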

15 changes: 15 additions & 0 deletions docs/development.rst
@@ -81,6 +81,21 @@ An example from the *model.txt* of the airlines model::

The bitvectors of the first three categorical nodes are <1 x i32>, <1 x i32> and <8 x i32> long.

Multiclass prediction
*********************

Multiclass prediction works by fitting an individual forest for each class and then running a
softmax across the per-class outputs.
For 3 classes trained over 100 iterations, LightGBM therefore generates 300 trees.
The trees are stored in the model.txt in strides, like so::

    tree 0 # (class 0, iteration 0)
    tree 1 # (class 1, iteration 0)
    tree 2 # (class 2, iteration 0)
    tree 3 # (class 0, iteration 1)
    tree 4 # (class 1, iteration 1)
    ...
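
In pure Python, the per-class aggregation and the softmax have roughly these semantics (an
illustrative sketch of the computation, not the code lleaves generates)::

    import math

    def predict_row(trees, row, n_classes):
        # trees[i] belongs to class i % n_classes, following the stride above
        raw = [0.0] * n_classes
        for idx, tree in enumerate(trees):
            raw[idx % n_classes] += tree(row)
        # softmax across the per-class raw scores
        exps = [math.exp(v) for v in raw]
        total = sum(exps)
        return [e / total for e in exps]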

Software Architecture Overview
------------------------------

46 changes: 23 additions & 23 deletions lleaves/compiler/ast/nodes.py
@@ -1,35 +1,35 @@
from lleaves.compiler.utils import DecisionType
from dataclasses import dataclass


class Forest:
def __init__(
self,
trees: list,
features: list,
objective_func: str,
objective_func_config: str,
):
self.trees = trees
self.n_args = len(features)
self.features = features
self.objective_func = objective_func
self.objective_func_config = objective_func_config
class Node:
@property
def is_leaf(self):
return isinstance(self, LeafNode)


@dataclass
class Tree:
def __init__(self, idx, root_node, features):
self.idx = idx
self.root_node = root_node
self.features = features
idx: int
root_node: Node
features: list
class_id: int

def __str__(self):
return f"tree_{self.idx}"


class Node:
@dataclass
class Forest:
trees: list[Tree]
features: list
n_classes: int
objective_func: str
objective_func_config: str

@property
def is_leaf(self):
return isinstance(self, LeafNode)
def n_args(self):
return len(self.features)


class DecisionNode(Node):
@@ -74,10 +74,10 @@ def __str__(self):
return f"node_{self.idx}"


@dataclass
class LeafNode(Node):
def __init__(self, idx, value):
self.idx = idx
self.value = value
idx: int
value: float

def __str__(self):
return f"leaf_{self.idx}"
19 changes: 13 additions & 6 deletions lleaves/compiler/ast/parser.py
@@ -1,3 +1,5 @@
import itertools

from lleaves.compiler.ast.nodes import DecisionNode, Forest, LeafNode, Tree
from lleaves.compiler.ast.scanner import scan_model_file
from lleaves.compiler.utils import DecisionType
@@ -18,7 +20,7 @@ def __init__(self, is_categorical):
self.is_categorical = is_categorical


def _parse_tree_to_ast(tree_struct, features):
def _parse_tree_to_ast(tree_struct, features, class_id):
n_nodes = len(tree_struct["decision_type"])
leaves = [
LeafNode(idx, value) for idx, value in enumerate(tree_struct["leaf_value"])
@@ -78,17 +80,19 @@ def _parse_tree_to_ast(tree_struct, features):
node.validate()

if nodes:
return Tree(tree_struct["Tree"], nodes[0], features)
return Tree(tree_struct["Tree"], nodes[0], features, class_id)
else:
# special case for when tree is just single leaf
assert len(leaves) == 1
return Tree(tree_struct["Tree"], leaves[0], features)
return Tree(tree_struct["Tree"], leaves[0], features, class_id)


def parse_to_ast(model_path):
scanned_model = scan_model_file(model_path)

n_args = scanned_model["general_info"]["max_feature_idx"] + 1
n_classes = scanned_model["general_info"]["num_class"]
assert n_classes == scanned_model["general_info"]["num_tree_per_iteration"]
objective = scanned_model["general_info"]["objective"]
objective_func = objective[0]
objective_func_config = objective[1] if len(objective) > 1 else None
@@ -99,10 +103,13 @@ def parse_to_ast(model_path):
assert n_args == len(features), "Ill formed model file"

trees = [
_parse_tree_to_ast(tree_struct, features)
for tree_struct in scanned_model["trees"]
_parse_tree_to_ast(scanned_tree, features, class_id)
for scanned_tree, class_id in zip(
scanned_model["trees"], itertools.cycle(range(n_classes))
)
]
return Forest(trees, features, objective_func, objective_func_config)
assert len(trees) % n_classes == 0, "Ill formed model file"
return Forest(trees, features, n_classes, objective_func, objective_func_config)


def is_categorical_feature(feature_info: str):
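
The round-robin assignment of trees to classes in `parse_to_ast` is easiest to see in isolation
(a standalone sketch, not part of this diff):

    import itertools

    n_classes = 3
    scanned_trees = [f"tree {i}" for i in range(6)]

    # zip with itertools.cycle pairs the trees with class ids 0, 1, 2, 0, 1, 2, ...
    pairs = list(zip(scanned_trees, itertools.cycle(range(n_classes))))
    # [('tree 0', 0), ('tree 1', 1), ('tree 2', 2), ('tree 3', 0), ('tree 4', 1), ('tree 5', 2)]
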
2 changes: 2 additions & 0 deletions lleaves/compiler/ast/scanner.py
@@ -71,6 +71,8 @@ def __init__(self, type: type, is_list=False, null_ok=False):

INPUT_SCAN_KEYS = {
"max_feature_idx": ScannedValue(int),
"num_class": ScannedValue(int),
"num_tree_per_iteration": ScannedValue(int),
"version": ScannedValue(str),
"feature_infos": ScannedValue(str, True),
"objective": ScannedValue(str, True),
79 changes: 59 additions & 20 deletions lleaves/compiler/codegen/codegen.py
@@ -1,3 +1,5 @@
from dataclasses import dataclass

from llvmlite import ir

from lleaves.compiler.utils import ISSUE_ERROR_MSG, MissingType
@@ -28,6 +30,14 @@ def dconst(value):
return ir.Constant(DOUBLE, value)


@dataclass
class LTree:
"""Class for the LLVM function of a tree paired with relevant non-LLVM context"""

llvm_function: ir.Function
class_id: int


def gen_forest(forest, module):
"""
Populate the passed IR module with code for the forest.
@@ -81,10 +91,14 @@ def make_tree(tree):
tree_func.linkage = "private"
# populate function with IR
gen_tree(tree, tree_func)
return tree_func
return LTree(llvm_function=tree_func, class_id=tree.class_id)

tree_funcs = [make_tree(tree) for tree in forest.trees]

if forest.n_classes > 1:
# better locality by running trees for each class together
tree_funcs.sort(key=lambda t: t.class_id)

_populate_forest_func(forest, root_func, tree_funcs)


@@ -189,17 +203,30 @@ def _populate_instruction_block(
else:
args.append(el)
# iterate over each tree, sum up results
res = builder.call(tree_funcs[0], args)
for func in tree_funcs[1:]:
tree_res = builder.call(func, args)
res = builder.fadd(tree_res, res)
ptr = builder.gep(out_arr, (loop_iter_reg,))
res = builder.fadd(res, builder.load(ptr))
results = [dconst(0.0) for _ in range(forest.n_classes)]
for func in tree_funcs:
tree_res = builder.call(func.llvm_function, args)
results[func.class_id] = builder.fadd(tree_res, results[func.class_id])
res_idx = builder.mul(iconst(forest.n_classes), loop_iter_reg)
results_ptr = [
builder.gep(out_arr, (builder.add(res_idx, iconst(class_idx)),))
for class_idx in range(forest.n_classes)
]

results = [
builder.fadd(result, builder.load(result_ptr))
for result, result_ptr in zip(results, results_ptr)
]
if eval_obj_func:
res = _populate_objective_func_block(
builder, res, forest.objective_func, forest.objective_func_config
results = _populate_objective_func_block(
builder,
results,
forest.objective_func,
forest.objective_func_config,
)
builder.store(res, ptr)
for result, result_ptr in zip(results, results_ptr):
builder.store(result, result_ptr)

tmpp1 = builder.add(loop_iter_reg, iconst(1))
builder.store(tmpp1, loop_iter)
builder.branch(condition_block)
@@ -230,7 +257,7 @@ def _populate_forest_func(forest, root_func, tree_funcs):


def _populate_objective_func_block(
builder, input, objective: str, objective_config: str
builder, args, objective: str, objective_config: str
):
"""
Takes the objective function specification and generates the code for it into the builder
@@ -246,23 +273,23 @@ def _populate_sigmoid(alpha):
raise ValueError(f"Sigmoid parameter needs to be >0, is {alpha}")

# 1 / (1 + exp(- alpha * x))
inner = builder.fmul(dconst(-alpha), input)
inner = builder.fmul(dconst(-alpha), args[0])
exp = builder.call(llvm_exp, [inner])
denom = builder.fadd(dconst(1.0), exp)
return builder.fdiv(dconst(1.0), denom)

if objective == "binary":
alpha = objective_config.split(":")[1]
return _populate_sigmoid(float(alpha))
result = _populate_sigmoid(float(alpha))
elif objective in ("xentropy", "cross_entropy"):
return _populate_sigmoid(1.0)
result = _populate_sigmoid(1.0)
elif objective in ("xentlambda", "cross_entropy_lambda"):
# naive implementation which will be numerically unstable for small x.
# should be changed to log1p
exp = builder.call(llvm_exp, [input])
return builder.call(llvm_log, [builder.fadd(dconst(1.0), exp)])
exp = builder.call(llvm_exp, [args[0]])
result = builder.call(llvm_log, [builder.fadd(dconst(1.0), exp)])
elif objective in ("poisson", "gamma", "tweedie"):
return builder.call(llvm_exp, [input])
result = builder.call(llvm_exp, [args[0]])
elif objective in (
"regression",
"regression_l1",
@@ -272,15 +299,27 @@ def _populate_sigmoid(alpha):
"mape",
):
if objective_config and "sqrt" in objective_config:
return builder.call(llvm_copysign, [builder.fmul(input, input), input])
arg = args[0]
result = builder.call(llvm_copysign, [builder.fmul(arg, arg), arg])
else:
return input
result = args[0]
elif objective in ("lambdarank", "rank_xendcg", "custom"):
return input
result = args[0]
elif objective == "multiclass":
assert len(args)
# TODO Check vectorization / vectorize by hand
result = [builder.call(llvm_exp, [arg]) for arg in args]

denominator = dconst(0.0)
for r in result:
denominator = builder.fadd(r, denominator)

result = [builder.fdiv(r, denominator) for r in result]
else:
raise ValueError(
f"Objective '{objective}' not yet implemented. {ISSUE_ERROR_MSG}"
)
return result if len(args) > 1 else [result]


def _populate_categorical_node_block(
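
To summarize the new output layout: for a multiclass model the forest function writes
`n_classes` values per input row at `out_arr[row * n_classes + class_id]`, sums each class's
trees separately, and then applies the softmax objective to the per-row sums. A rough NumPy
sketch of that computation (an illustration of the semantics only, not the generated LLVM IR):

    import numpy as np

    def forest_predict(rows, tree_funcs, class_ids, n_classes, out):
        # out is a flat buffer of length len(rows) * n_classes, laid out row-major
        for i, row in enumerate(rows):
            # start from whatever is already in the output buffer, as the IR does
            raw = out[i * n_classes : (i + 1) * n_classes].copy()
            for func, class_id in zip(tree_funcs, class_ids):
                raw[class_id] += func(row)
            # "multiclass" objective: softmax across the per-class sums
            probs = np.exp(raw) / np.exp(raw).sum()
            out[i * n_classes : (i + 1) * n_classes] = probs
        return out
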
26 changes: 15 additions & 11 deletions lleaves/data_processing.py
@@ -149,20 +149,24 @@ def extract_pandas_traintime_categories(file_path):
raise ValueError("Ill formatted model file!")


def extract_num_feature(file_path):
def extract_n_features_n_classes(file_path):
"""
Extract number of features expected by this model as 'max_feature_idx' + 1
Extract number of features and the number of classes of this model
:param file_path: path to model.txt
:return: the number of features expected by this model.
:return: dict with "n_feature": number of features, "n_class": number of classes
"""
res = {}
with open(file_path, "r") as f:
line = f.readline()
while line and not line.startswith("max_feature_idx"):
for _ in range(2):
line = f.readline()

if line.startswith("max_feature_idx"):
n_args = int(line.split("=")[1]) + 1
else:
raise ValueError("Ill formatted model file!")
return n_args
while line and not line.startswith(("max_feature_idx", "num_class")):
line = f.readline()

if line.startswith("max_feature_idx"):
res["n_feature"] = int(line.split("=")[1]) + 1
elif line.startswith("num_class"):
res["n_class"] = int(line.split("=")[1])
else:
raise ValueError("Ill formatted model file!")
return res
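
As a usage illustration (the header values are made up; `n_feature` is `max_feature_idx + 1`):

    # model.txt header excerpt:
    #   max_feature_idx=9
    #   num_class=3
    from lleaves.data_processing import extract_n_features_n_classes

    info = extract_n_features_n_classes("model.txt")
    # -> {"n_feature": 10, "n_class": 3}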