Extracting Decision Rules from Scikit-learn Decision Trees

Learn various methods to interpret and visualize the decision rules embedded within scikit-learn Decision Tree and Random Forest models for better understanding and debugging.
Decision trees are powerful and intuitive machine learning models, prized for their interpretability. However, once a tree is trained, extracting the exact decision rules it learned can be challenging, especially for deep trees or ensembles such as Random Forests. This article explores several techniques to extract, visualize, and understand these rules from scikit-learn's DecisionTreeClassifier and DecisionTreeRegressor.
Understanding Decision Tree Structure
A decision tree makes predictions by traversing a series of nodes, each representing a decision based on a feature. It starts at the root node and follows branches until it reaches a leaf node, which provides the final prediction. Each path from the root to a leaf represents a specific decision rule. Understanding this structure is key to extracting the rules.
flowchart TD
    A[Root Node: Feature X <= Threshold 1?] -->|Yes| B{Node 1: Feature Y <= Threshold 2?}
    A -->|No| C{Node 2: Feature Z <= Threshold 3?}
    B -->|Yes| D[Leaf Node: Prediction 1]
    B -->|No| E[Leaf Node: Prediction 2]
    C -->|Yes| F[Leaf Node: Prediction 3]
    C -->|No| G[Leaf Node: Prediction 4]
Simplified Decision Tree Structure
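Before reaching for any of the methods below, it is worth knowing that scikit-learn ships a built-in helper, export_text, that prints the learned rules as indented text with no extra dependencies. A minimal sketch on the iris dataset (the same setup used throughout this article):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Train a small tree on the iris dataset
iris = load_iris()
dtree = DecisionTreeClassifier(max_depth=3, random_state=42)
dtree.fit(iris.data, iris.target)

# export_text renders each split as an indented "|---" line;
# leaves show the predicted class
rules_text = export_text(dtree, feature_names=iris.feature_names)
print(rules_text)
```

For a quick sanity check of what the model learned, this is often all you need.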
Method 1: Exporting to Graphviz
The most common and visually appealing way to understand a scikit-learn decision tree is to export it to the DOT format and render it using Graphviz. This provides a graphical representation of the entire tree, making it easy to trace decision paths.
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.datasets import load_iris
import graphviz
# Load dataset
iris = load_iris()
X, y = iris.data, iris.target
feature_names = iris.feature_names
target_names = iris.target_names
# Train a Decision Tree Classifier
dtree = DecisionTreeClassifier(max_depth=3, random_state=42)
dtree.fit(X, y)
# Export to DOT format
dot_data = export_graphviz(
    dtree,
    out_file=None,
    feature_names=feature_names,
    class_names=target_names,
    filled=True,
    rounded=True,
    special_characters=True,
)
# Render with Graphviz
graph = graphviz.Source(dot_data)
graph.render("iris_decision_tree", view=True, format='png') # Saves and opens the image
Exporting a Decision Tree to Graphviz
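If installing the Graphviz system package is not an option, scikit-learn's plot_tree function draws an equivalent diagram using only matplotlib. A sketch with the same classifier as above (the output filename is arbitrary):

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line for an interactive window
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
dtree = DecisionTreeClassifier(max_depth=3, random_state=42)
dtree.fit(iris.data, iris.target)

# plot_tree draws the tree onto a matplotlib Axes
fig, ax = plt.subplots(figsize=(12, 8))
plot_tree(
    dtree,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,
    rounded=True,
    ax=ax,
)
fig.savefig("iris_decision_tree_mpl.png", dpi=150)
```

The rendering is less polished than Graphviz output but avoids any system-level dependency.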
Note: You need Graphviz installed (run pip install graphviz for the Python bindings and install the Graphviz system package) for this method to work. The view=True argument will attempt to open the generated image file.
Method 2: Programmatic Rule Extraction
For more detailed analysis, or when you need to process rules programmatically, you can traverse the tree structure directly using the tree_ attribute of the trained DecisionTreeClassifier or DecisionTreeRegressor object. This attribute exposes arrays that describe the tree's nodes, such as children_left, children_right, feature, threshold, and value.
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
# Load dataset
iris = load_iris()
X, y = iris.data, iris.target
feature_names = iris.feature_names
target_names = iris.target_names
# Train a Decision Tree Classifier
dtree = DecisionTreeClassifier(max_depth=3, random_state=42)
dtree.fit(X, y)
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != -2 else "undefined!"  # -2 marks leaf nodes
        for i in tree_.feature
    ]

    def recurse(node, parent_rule):
        if tree_.feature[node] != -2:  # Internal (split) node
            threshold = tree_.threshold[node]
            feature = feature_name[node]
            # Left child rule
            left_rule = f"{parent_rule} AND {feature} <= {threshold:.2f}" if parent_rule else f"{feature} <= {threshold:.2f}"
            recurse(tree_.children_left[node], left_rule)
            # Right child rule
            right_rule = f"{parent_rule} AND {feature} > {threshold:.2f}" if parent_rule else f"{feature} > {threshold:.2f}"
            recurse(tree_.children_right[node], right_rule)
        else:  # Leaf node
            class_idx = tree_.value[node].argmax()
            class_label = class_names[class_idx]
            print(f"Rule: {parent_rule} => Predict: {class_label}")

    recurse(0, "")
print("\nExtracted Decision Rules:")
get_rules(dtree, feature_names, target_names)
Programmatic Extraction of Decision Rules
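The same tree_ arrays also answer the inverse question: which rule fired for a specific sample? The built-in decision_path method returns the nodes a sample visits, and apply returns the leaf it lands in. A sketch assuming the dtree classifier trained above:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data
dtree = DecisionTreeClassifier(max_depth=3, random_state=42)
dtree.fit(X, iris.target)

sample = X[0:1]                                # trace a single sample
node_indicator = dtree.decision_path(sample)   # sparse matrix of visited nodes
leaf_id = dtree.apply(sample)[0]               # leaf the sample ends up in

tree_ = dtree.tree_
for node_id in node_indicator.indices:
    if node_id == leaf_id:
        pred = iris.target_names[tree_.value[node_id].argmax()]
        print(f"Leaf {node_id}: predict {pred}")
        break
    feature_idx = tree_.feature[node_id]
    threshold = tree_.threshold[node_id]
    # Report which side of the split this sample took
    op = "<=" if sample[0, feature_idx] <= threshold else ">"
    print(f"Node {node_id}: {iris.feature_names[feature_idx]} {op} {threshold:.2f}")
```

This is handy for debugging individual predictions rather than dumping the whole rule set.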
Method 3: Visualizing Individual Trees in a Random Forest
Random Forests are ensembles of decision trees. While you can't easily visualize the entire forest as one, you can extract and visualize individual trees from the forest. This is useful for understanding the diversity of decisions made by different trees within the ensemble.
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.tree import export_graphviz
import graphviz
# Load dataset
iris = load_iris()
X, y = iris.data, iris.target
feature_names = iris.feature_names
target_names = iris.target_names
# Train a Random Forest Classifier
rf_model = RandomForestClassifier(n_estimators=10, max_depth=3, random_state=42)
rf_model.fit(X, y)
# Extract and visualize the first tree from the forest
estimator = rf_model.estimators_[0] # Get the first tree
dot_data = export_graphviz(
    estimator,
    out_file=None,
    feature_names=feature_names,
    class_names=target_names,
    filled=True,
    rounded=True,
    special_characters=True,
)
graph = graphviz.Source(dot_data)
graph.render("random_forest_tree_0", view=True, format='png')
Visualizing a Single Tree from a Random Forest
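Inspecting trees one by one does not scale to large forests, but the per-tree tree_ arrays can be aggregated. As one illustrative sketch (not a standard scikit-learn utility), the snippet below counts how often each feature is used as a split across all trees in the forest, giving a rough picture of the rules the ensemble relies on:

```python
from collections import Counter
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
rf_model = RandomForestClassifier(n_estimators=10, max_depth=3, random_state=42)
rf_model.fit(iris.data, iris.target)

split_counts = Counter()
for est in rf_model.estimators_:
    for f in est.tree_.feature:
        if f >= 0:  # negative values mark leaf nodes
            split_counts[iris.feature_names[f]] += 1

# Features that appear in many splits dominate the forest's rules
for name, count in split_counts.most_common():
    print(f"{name}: used in {count} splits")
```

This complements feature_importances_, which weights splits by impurity reduction rather than simply counting them.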