Skip to content

Inference boosted tree #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: inference-processor-simple
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,15 @@
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessFactory;
import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.MemoryUsageEstimationProcessManager;
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory;
import org.elasticsearch.xpack.ml.dataframe.process.NativeAnalyticsProcessFactory;
import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import org.elasticsearch.xpack.ml.inference.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.ModelLoader;
import org.elasticsearch.xpack.ml.inference.sillymodel.SillyModelLoader;
import org.elasticsearch.xpack.ml.inference.tree.TreeModelLoader;
import org.elasticsearch.xpack.ml.job.JobManager;
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier;
Expand Down Expand Up @@ -628,7 +629,8 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
}

private Map<String, ModelLoader> getModelLoaders(Client client) {
return Map.of(SillyModelLoader.MODEL_TYPE, new SillyModelLoader(client));
return Map.of(SillyModelLoader.MODEL_TYPE, new SillyModelLoader(client),
TreeModelLoader.MODEL_TYPE, new TreeModelLoader(client));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.ml.inference.tree;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.BiPredicate;


/**
* A decision tree that can make predictions given a feature vector
*/
public class Tree {
private final List<Node> nodes;

Tree(List<Node> nodes) {
this.nodes = Collections.unmodifiableList(nodes);
}

/**
* Trace the route predicting on the feature vector takes.
* @param features The feature vector
* @return The list of traversed nodes ordered from root to leaf
*/
public List<Node> trace(List<Double> features) {
return trace(features, 0, new ArrayList<>());
}

private List<Node> trace(List<Double> features, int nodeIndex, List<Node> visited) {
Node node = nodes.get(nodeIndex);
visited.add(node);
if (node.isLeaf()) {
return visited;
}

int nextNode = node.compare(features);
return trace(features, nextNode, visited);
}

/**
* Make a prediction based on the feature vector
* @param features The feature vector
* @return The prediction
*/
public Double predict(List<Double> features) {
return predict(features, 0);
}

private Double predict(List<Double> features, int nodeIndex) {
Node node = nodes.get(nodeIndex);
if (node.isLeaf()) {
return node.value();
}

int nextNode = node.compare(features);
return predict(features, nextNode);
}

/**
* Finds {@code null} nodes. If constructed properly there should be no {@code null} nodes.
* {@code null} nodes indicates missing leaf or junction nodes
*
* @return List of indexes to the {@code null} nodes
*/
List<Integer> missingNodes() {
List<Integer> nullNodeIndices = new ArrayList<>();
for (int i=0; i<nodes.size(); i++) {
if (nodes.get(i) == null) {
nullNodeIndices.add(i);
}
}
return nullNodeIndices;
}

@Override
public String toString() {
return nodes.toString();
}

public static class Node {
int leftChild;
int rightChild;
int featureIndex;
boolean isDefaultLeft;
double thresholdValue;
BiPredicate<Double, Double> operator;

Node(int leftChild, int rightChild, int featureIndex, boolean isDefaultLeft, double thresholdValue) {
this.leftChild = leftChild;
this.rightChild = rightChild;
this.featureIndex = featureIndex;
this.isDefaultLeft = isDefaultLeft;
this.thresholdValue = thresholdValue;
this.operator = (value, threshold) -> value < threshold; // less than
}

Node(int leftChild, int rightChild, int featureIndex, boolean isDefaultLeft, double thresholdValue,
BiPredicate<Double, Double> operator) {
this.leftChild = leftChild;
this.rightChild = rightChild;
this.featureIndex = featureIndex;
this.isDefaultLeft = isDefaultLeft;
this.thresholdValue = thresholdValue;
this.operator = operator;
}

Node(double value) {
this(-1, -1, -1, false, value);
}

public boolean isLeaf() {
return leftChild < 1;
}

int compare(List<Double> features) {
Double feature = features.get(featureIndex);
if (isMissing(feature)) {
return isDefaultLeft ? leftChild : rightChild;
}

return operator.test(feature, thresholdValue) ? leftChild : rightChild;
}

boolean isMissing(Double feature) {
return feature == null;
}

public Double value() {
return thresholdValue;
}

public int getFeatureIndex() {
return featureIndex;
}

public int getLeftChild() {
return leftChild;
}

public int getRightChild() {
return rightChild;
}

@Override
public String toString() {
StringBuilder builder = new StringBuilder("{\n");
builder.append("left: ").append(leftChild).append('\n');
builder.append("right: ").append(rightChild).append('\n');
builder.append("isDefaultLeft: ").append(isDefaultLeft).append('\n');
builder.append("isLeaf: ").append(isLeaf()).append('\n');
builder.append("featureIndex: ").append(featureIndex).append('\n');
builder.append("value: ").append(thresholdValue).append('\n');
builder.append("}\n");
return builder.toString();
}
}


public static class TreeBuilder {

private final ArrayList<Node> nodes;
private int numNodes;

public static TreeBuilder newTreeBuilder() {
return new TreeBuilder();
}

TreeBuilder() {
nodes = new ArrayList<>();
// allocate space in the root node and set to a leaf
nodes.add(null);
addLeaf(0, 0.0);
numNodes = 1;
}

/**
* Add a decision node. Space for the child nodes is allocated
* @param nodeIndex Where to place the node. This is either 0 (root) or an existing child node index
* @param featureIndex The feature index the decision is made on
* @param isDefaultLeft Default left branch if the feature is missing
* @param decisionThreshold The decision threshold
* @return The created node
*/
public Node addJunction(int nodeIndex, int featureIndex, boolean isDefaultLeft, double decisionThreshold) {
int leftChild = numNodes++;
int rightChild = numNodes++;
nodes.ensureCapacity(nodeIndex +1);
for (int i=nodes.size(); i<nodeIndex +1; i++) {
nodes.add(null);
}

Node node = new Node(leftChild, rightChild, featureIndex, isDefaultLeft, decisionThreshold);
nodes.set(nodeIndex, node);

// allocate space for the child nodes
while (nodes.size() <= rightChild) {
nodes.add(null);
}

return node;
}

/**
* Sets the node at {@code nodeIndex} to a leaf node.
* @param nodeIndex The index as allocated by a call to {@link #addJunction(int, int, boolean, double)}
* @param value The prediction value
* @return this
*/
public TreeBuilder addLeaf(int nodeIndex, double value) {
for (int i=nodes.size(); i<nodeIndex +1; i++) {
nodes.add(null);
}

assert nodes.get(nodeIndex) == null : "expected null value at index " + nodeIndex;

nodes.set(nodeIndex, new Node(value));
return this;
}

public Tree build() {
return new Tree(nodes);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.ml.inference.tree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class TreeEnsembleModel {

private final List<Tree> trees;
private final Map<String, Integer> featureMap;

private TreeEnsembleModel(List<Tree> trees, Map<String, Integer> featureMap) {
this.trees = Collections.unmodifiableList(trees);
this.featureMap = featureMap;
}

public int numFeatures() {
return featureMap.size();
}

public int numTrees() {
return trees.size();
}

public List<Integer> checkForNull() {
List<Integer> missing = new ArrayList<>();
for (Tree tree : trees) {
missing.addAll(tree.missingNodes());
}
return missing;
}

public double predictFromDoc(Map<String, Object> features) {
List<Double> featureVec = docToFeatureVector(features);
List<Double> predictions = trees.stream().map(tree -> tree.predict(featureVec)).collect(Collectors.toList());
return mergePredictions(predictions);
}

public double predict(Map<String, Double> features) {
List<Double> featureVec = doubleDocToFeatureVector(features);
List<Double> predictions = trees.stream().map(tree -> tree.predict(featureVec)).collect(Collectors.toList());
return mergePredictions(predictions);
}

public List<List<Tree.Node>> trace(Map<String, Double> features) {
List<Double> featureVec = doubleDocToFeatureVector(features);
return trees.stream().map(tree -> tree.trace(featureVec)).collect(Collectors.toList());
}

double mergePredictions(List<Double> predictions) {
return predictions.stream().mapToDouble(f -> f).summaryStatistics().getSum();
}

List<Double> doubleDocToFeatureVector(Map<String, Double> features) {
List<Double> featureVec = Arrays.asList(new Double[featureMap.size()]);

for (Map.Entry<String, Double> keyValue : features.entrySet()) {
if (featureMap.containsKey(keyValue.getKey())) {
featureVec.set(featureMap.get(keyValue.getKey()), keyValue.getValue());
}
}

return featureVec;
}

List<Double> docToFeatureVector(Map<String, Object> features) {
List<Double> featureVec = Arrays.asList(new Double[featureMap.size()]);

for (Map.Entry<String, Object> keyValue : features.entrySet()) {
if (featureMap.containsKey(keyValue.getKey())) {
Double value = (Double)keyValue.getValue();
if (value != null) {
featureVec.set(featureMap.get(keyValue.getKey()), value);
}
}
}

return featureVec;
}

public static ModelBuilder modelBuilder(Map<String, Integer> featureMap) {
return new ModelBuilder(featureMap);
}

public static class ModelBuilder {
private List<Tree> trees;
private Map<String, Integer> featureMap;

public ModelBuilder(Map<String, Integer> featureMap) {
this.featureMap = featureMap;
trees = new ArrayList<>();
}

public ModelBuilder addTree(Tree tree) {
trees.add(tree);
return this;
}

public TreeEnsembleModel build() {
return new TreeEnsembleModel(trees, featureMap);
}
}
}
Loading