A Python library for interpreting model structure.
Currently the library supports DecisionTreeClassifier, DecisionTreeRegressor and RandomForestClassifier from scikit-learn.
Future versions of the library will cover other types of algorithms, such as RandomForestRegressor and XGBoost.
To become a better machine learning engineer, it is important to understand the model structure more deeply and to have an intuition of what happens when we change the model inputs and how those changes are reflected in model performance. By model inputs we mean adding more data, adding new features and changing model hyperparameters.
This library was developed with two main ideas in mind:
- help us better understand the model structure and the model results and, based on this, properly choose other hyperparameter values or another set of features for the next iteration
- justify/explain the predictions of ML models to both technical and non-technical people
pip install git+https://github.com/tlapusan/woodpecker.git
The well-known Titanic dataset was chosen to showcase the library's capabilities.
features = ["Pclass", "Age", "Fare", "Sex_label", "Cabin_label", "Embarked_label"]
target = "Survived"
Let's look at some descriptive statistics of the training set.
model = DecisionTreeClassifier(criterion="entropy", random_state=random_state, min_samples_split=20)
model.fit(train[features], train[target])
dts = DecisionTreeStructure(model, train, features, target)
You don't have to write all the code needed to extract feature importances, map them to feature names and sort them. Now you just call a simple utility function.
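As a rough illustration of the boilerplate such a helper hides, here is a minimal sketch that pairs importances with feature names and sorts them. The name `sorted_feature_importances` is hypothetical and not part of woodpecker's API; with scikit-learn the importances would come from the fitted model's `feature_importances_` attribute.

```python
def sorted_feature_importances(importances, feature_names):
    """Pair each importance with its feature name, most important first.

    Hypothetical helper for illustration; not woodpecker's implementation.
    """
    if len(importances) != len(feature_names):
        raise ValueError("importances and feature_names must have the same length")
    pairs = zip(feature_names, importances)
    # Sort by the importance value, largest first.
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)
```

For the Titanic model this would be called as `sorted_feature_importances(model.feature_importances_, features)`.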
Like the feature importance helper, this is a utility function that wraps all the code needed to visualize the decision tree structure using graphviz.
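A comparable sketch built on scikit-learn's own `export_graphviz`, which returns the tree as a DOT string when `out_file=None`. The wrapper name `tree_to_dot` is hypothetical; turning the DOT string into an image additionally assumes the `graphviz` Python package (for example `graphviz.Source(dot)`).

```python
from sklearn.tree import DecisionTreeClassifier, export_graphviz


def tree_to_dot(model, feature_names, class_names):
    """Return the structure of a fitted sklearn tree as Graphviz DOT source (hypothetical helper)."""
    return export_graphviz(
        model,
        out_file=None,  # return the DOT source instead of writing a file
        feature_names=feature_names,
        class_names=class_names,
        filled=True,  # color nodes by their majority class
        rounded=True,
    )


if __name__ == "__main__":
    # Tiny toy tree, just to show the call.
    toy = DecisionTreeClassifier(random_state=0).fit([[0], [1], [2], [3]], [0, 0, 1, 1])
    print(tree_to_dot(toy, ["x"], ["died", "survived"]))
```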
Impurity is a metric that shows how confident a leaf's prediction is.
In the case of entropy with a binary target such as Survived, impurity ranges from 0 to 1.
0 means the leaf is very confident in its predictions (all its training samples belong to one class); 1 means the opposite (the two classes are split evenly).
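A minimal sketch of base-2 entropy computed from a node's class counts, showing why a pure leaf scores 0 and an evenly split binary leaf scores 1. This is the standard formula, not code taken from woodpecker or scikit-learn. With more than two classes the value can exceed 1 (its maximum is log2 of the number of classes).

```python
import math


def entropy(class_counts):
    """Base-2 entropy of a node, given how many training samples of each class reached it."""
    total = sum(class_counts)
    result = 0.0
    for count in class_counts:
        if count > 0:  # 0 * log2(0) is taken as 0
            p = count / total
            result -= p * math.log2(p)
    return result


print(entropy([10, 0]))  # 0.0, pure leaf
print(entropy([5, 5]))   # 1.0, evenly split leaf
print(entropy([3, 1]))   # about 0.81, mostly one class
```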
The tree's performance is directly influenced by the performance of each leaf, so it's important to have a general overview of how leaf impurities are distributed.
dts.show_leaf_impurity_distribution(bins=40, figsize=(20, 7))
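To see where the values in such a distribution come from, the sketch below pulls per-leaf values out of a fitted scikit-learn tree. It relies on the low-level `tree_` arrays, where a leaf is a node whose `children_left` entry is -1. The helper name `leaf_values` is hypothetical and this is not necessarily how woodpecker builds its plot.

```python
def leaf_values(children_left, node_values):
    """Return {node_id: value} for every leaf of a tree.

    children_left: left-child id per node; scikit-learn stores -1 for leaves.
    node_values:   any per-node array, e.g. tree_.impurity or tree_.n_node_samples.
    """
    return {
        node_id: node_values[node_id]
        for node_id, left_child in enumerate(children_left)
        if left_child == -1
    }
```

With the Titanic model, `leaf_values(model.tree_.children_left, model.tree_.impurity)` gives the leaf impurities, and `plt.hist(list(impurities.values()), bins=40)` draws a histogram similar in spirit to the one above.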
Samples is a metric that shows how many examples from the training set reached a node.
For a leaf it's ideal to have an impurity very close to 0, but it's equally important
to have a significant number of samples reaching that leaf. If the number of samples is very small, it could be a sign
of overfitting for that leaf.
That's why it's important to look at both leaf impurity (previous plot) and samples to get a better understanding of the tree's performance.
dts.show_leaf_samples_distribution(bins=40, figsize=(20, 7))
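Combining the two metrics, a simple check is to flag leaves that look confident but are backed by only a few training samples. A minimal sketch over scikit-learn's `tree_` arrays; the function name `pure_but_small_leaves` and the `max_impurity`/`min_samples` defaults are hypothetical, illustrative choices rather than recommended values.

```python
def pure_but_small_leaves(children_left, impurity, n_node_samples,
                          max_impurity=0.1, min_samples=5):
    """Return ids of leaves with low impurity but few training samples.

    Such leaves may be memorizing a handful of training rows (possible overfitting).
    The default thresholds are illustrative only.
    """
    return [
        node_id
        for node_id, left_child in enumerate(children_left)
        if left_child == -1  # scikit-learn marks leaves with -1
        and impurity[node_id] <= max_impurity
        and n_node_samples[node_id] < min_samples
    ]
```

For a fitted model: `pure_but_small_leaves(t.children_left, t.impurity, t.n_node_samples)` with `t = model.tree_`.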
Sometimes we want to investigate the behavior of individual leaves.
We could analyze leaves with very good, medium or very poor performance.
plt.subplot(3, 1, 3)
dts.show_leaf_samples_by_class()
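The per-class breakdown of a leaf can be reproduced by counting training labels per leaf. A minimal sketch, assuming the leaf of each training sample comes from scikit-learn's `model.apply(train[features])`; the function name `class_counts_by_leaf` is hypothetical.

```python
from collections import Counter, defaultdict


def class_counts_by_leaf(leaf_ids, labels):
    """Count training labels per leaf.

    leaf_ids: leaf reached by each training sample, e.g. model.apply(train[features]).
    labels:   the matching target values, e.g. train[target].
    """
    counts = defaultdict(Counter)
    for leaf_id, label in zip(leaf_ids, labels):
        counts[leaf_id][label] += 1
    # Convert to plain dicts so the result prints cleanly.
    return {leaf_id: dict(counter) for leaf_id, counter in counts.items()}
```

Counting from `model.apply` instead of reading `tree_.value` keeps the result in raw sample counts regardless of how the installed scikit-learn version stores node values.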
This function returns a dataframe with all training samples reaching a node. After looking at individual leaf metrics, we can see that some leaves are interesting. For example, leaf 19 has impurity 0, a lot of samples, and all of its passengers survived (Survived=1). Getting the samples from such a leaf can help us discover patterns in the data or understand why a leaf performs well or badly.
dts.get_node_samples(node_id=19)[features + [target]].describe()
We can see that the majority of passengers were from a high socio-economic class (Pclass = 1), most were young to middle-aged, bought an expensive ticket (the mean Fare over the training set is 32) and all were women.
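A similar selection can be sketched directly with scikit-learn: `decision_path` returns a sparse `(n_samples, n_nodes)` indicator matrix, so the rows reaching a node are those with a non-zero entry in that node's column. The helper name `node_samples` is hypothetical; it works for internal nodes as well as leaves.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier


def node_samples(model, data, features, node_id):
    """Return the rows of `data` whose decision path passes through `node_id`."""
    # Sparse indicator matrix: entry (i, j) is non-zero if sample i visits node j.
    indicator = model.decision_path(data[features])
    reached = indicator[:, node_id].toarray().ravel().astype(bool)
    return data[reached]


if __name__ == "__main__":
    toy = pd.DataFrame({"x": [0, 1, 2, 3], "y": [0, 0, 1, 1]})
    tree = DecisionTreeClassifier(random_state=0).fit(toy[["x"]], toy["y"])
    print(node_samples(tree, toy, ["x"], node_id=0).describe())
```

On the Titanic model the equivalent call would be `node_samples(model, train, features, 19)[features + [target]].describe()`.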
There will be moments when we need to justify why our model predicted a specific value. Looking at the whole tree and tracking the prediction path by hand is time-consuming when the tree is deep.
Let's look at the prediction path for the following sample:
Pclass 3.0
Age 28.0
Fare 15.5
Sex_label 0.0
Cabin_label -1.0
Embarked_label 1.0
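Conceptually, a prediction path is a walk from the root to a leaf: scikit-learn sends a sample to the left child when its feature value is `<=` the node's threshold. A minimal sketch over the `tree_` arrays (`children_left`, `children_right`, `feature`, `threshold`); the function name `prediction_path` is hypothetical. scikit-learn compares float32-cast inputs internally, so for an authoritative answer on values sitting exactly at a threshold, `model.decision_path` can be used instead.

```python
def prediction_path(children_left, children_right, feature, threshold,
                    sample, feature_names):
    """Walk a fitted tree from the root and describe every split the sample takes.

    sample: a dict or pandas Series mapping feature name -> value.
    Returns (leaf_id, list of human-readable split descriptions).
    """
    node_id = 0
    steps = []
    while children_left[node_id] != -1:  # -1 marks a leaf in scikit-learn
        name = feature_names[feature[node_id]]
        value = sample[name]
        if value <= threshold[node_id]:
            steps.append(f"{name} = {value} <= {threshold[node_id]}")
            node_id = children_left[node_id]
        else:
            steps.append(f"{name} = {value} > {threshold[node_id]}")
            node_id = children_right[node_id]
    return node_id, steps
```

With `t = model.tree_`, the sample above could be passed as a dict such as `{"Pclass": 3.0, "Age": 28.0, "Fare": 15.5, "Sex_label": 0.0, "Cabin_label": -1.0, "Embarked_label": 1.0}` together with `features`.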
This visualization shows the training data splits the model was built on. It can also be used as a way to learn how a decision tree is built.
The sample is the same as above.
dts.show_decision_tree_splits_prediction(sample, bins=20)
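A rough sketch of this kind of plot: one histogram per node on the prediction path, with the split threshold and the sample's value marked. The function name `plot_path_splits` and its input format (a list of `(feature_name, training_values, threshold, sample_value)` tuples, where the training values belong to the rows reaching that node) are assumptions for illustration, not woodpecker's implementation. It uses matplotlib's object-oriented `Figure` API, so it does not depend on pyplot state.

```python
from matplotlib.figure import Figure


def plot_path_splits(splits, bins=20):
    """Draw one histogram per split on a prediction path.

    splits: list of (feature_name, training_values, threshold, sample_value),
            one entry per internal node visited by the sample.
    """
    fig = Figure(figsize=(6, 3 * len(splits)))
    axes = fig.subplots(len(splits), 1, squeeze=False)[:, 0]
    for ax, (name, values, threshold, sample_value) in zip(axes, splits):
        ax.hist(values, bins=bins, color="lightgray")
        # Dashed line: where the node splits the training data.
        ax.axvline(threshold, color="red", linestyle="--",
                   label=f"split: {name} <= {threshold:.2f}")
        # Solid line: where the explained sample falls.
        ax.axvline(sample_value, color="blue", label=f"sample: {sample_value}")
        ax.set_xlabel(name)
        ax.legend()
    return fig
```

The returned figure can be saved with `fig.savefig("splits.png")`.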
For visualizations of other algorithms, take a look inside the notebooks folder.
- 0.1
- model structure investigation for DecisionTreeClassifier
- 0.2
- add visualisation for correct/wrong leaf predictions
- add setup.py file
Tudor Lapusan
Twitter: @tlapusan
Email: tudor.lapusan@gmail.com
- jupyter
- matplotlib
- scikit-learn
- pandas
This project is licensed under the terms of the MIT license, see LICENSE.