Changing Colors for Decision Tree Plot Using Sklearn plot_tree

Last Updated : 12 Sep, 2024

Decision trees are a popular machine learning model used for classification and regression tasks. Visualizing a decision tree shows how the model reaches its decisions, which makes the results easier to interpret and explain. Scikit-learn, a widely used machine learning library in Python, offers a convenient function called plot_tree for visualizing decision trees. This article will guide you through the process of customizing the colors of decision tree plots using plot_tree from scikit-learn.

Understanding the Basics of plot_tree in Scikit-learn

Before diving into color customization, let's briefly review the basic usage of sklearn's plot_tree function. This function is part of the sklearn.tree module and provides a straightforward way to visualize decision trees. It generates a plot that represents the structure of the decision tree, including the nodes and the decision rules.

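The function requires matplotlib to be installed. Its parameters control what is drawn: for example, filled colors the nodes, rounded draws the boxes with rounded corners, fontsize sets the text size, and max_depth limits how many levels of the tree are shown.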

To use plot_tree, you need to have a trained decision tree model. Here's a basic example:

Python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

# Load dataset and train a decision tree classifier
iris = load_iris()
X, y = iris.data, iris.target
clf = DecisionTreeClassifier()
clf.fit(X, y)

# Plot the decision tree
plt.figure(figsize=(20, 10))
plot_tree(clf, filled=True)
plt.show()

Output:

[Figure: plot_tree in Scikit-learn]

The plot_tree function takes several parameters, including the trained classifier, feature_names, class_names, and a boolean filled parameter. When filled=True, nodes are colored to indicate the majority class for classification, the extremity of values for regression, or the purity of the node for multi-output models.
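As a minimal sketch of those parameters in a single call, the example below passes the Iris feature and class names together with a few display options. The depth limit, `random_state`, figure size, and font size are illustrative assumptions rather than recommended values.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

iris = load_iris()
# A shallow tree (max_depth=2 is an arbitrary choice) keeps the plot readable
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(10, 6))
annotations = plot_tree(
    clf,
    feature_names=list(iris.feature_names),  # names shown in the split rules
    class_names=list(iris.target_names),     # names shown for the majority class
    filled=True,                             # color nodes by majority class
    rounded=True,                            # draw boxes with rounded corners
    fontsize=10,
    ax=ax,
)

# plot_tree returns the matplotlib annotations it drew, one per node here
print(len(annotations), "annotations for", clf.tree_.node_count, "nodes")
```

Call `plt.show()` afterwards to display the figure. The returned list of annotations is what the later sections use to change colors.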

Changing Node Colors

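One of the most common customizations is changing the colors of the nodes in the decision tree plot. By default, when filled=True, plot_tree uses a predefined color scheme: for a classifier, each class gets its own hue, and a node's fill fades toward white as its class mix becomes less pure. However, we can modify this to suit our needs.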
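To see what the predefined scheme produces, one option is to read the fill color back from the annotations that `plot_tree` returns. The sketch below assumes the returned list follows the classifier's node numbering, which should hold for a tree grown with default settings and drawn in full; the depth limit and `random_state` are illustrative choices.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(10, 6))
annotations = plot_tree(clf, filled=True, ax=ax)

# Read the default fill color of every node as a hex string
default_colors = [
    mcolors.to_hex(ann.get_bbox_patch().get_facecolor()) for ann in annotations
]

# Assumption: annotations[i] corresponds to node i of clf.tree_
for node_id, color in enumerate(default_colors):
    majority = iris.target_names[clf.tree_.value[node_id][0].argmax()]
    print(f"node {node_id}: majority={majority}, fill={color}")
```

Comparing the printed colors of mixed nodes and pure leaves is a quick way to check how the default shading behaves before replacing it.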

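plot_tree has no parameter for custom colors, but it returns the matplotlib annotations it draws, one per node. For fine-grained control, you can define a custom function that assigns colors based on specific criteria and apply it to those annotations:

Python
# Custom node color function: maps a per-node value to a color.
# Using the node's Gini impurity as that value is an assumption made here;
# any per-node quantity would work.
def node_color_function(value):
    if value < 0.3:
        return "lightblue"
    elif value < 0.7:
        return "lightgreen"
    else:
        return "pink"

fig, ax = plt.subplots(figsize=(20, 10))
annotations = plot_tree(clf,
                        feature_names=list(iris.feature_names),
                        class_names=list(iris.target_names),
                        filled=True,
                        ax=ax)

# Recolor each node's box. This relies on plot_tree drawing nodes in the same
# order as clf.tree_ numbers them (a fully drawn tree grown with default settings).
for node_id, annotation in enumerate(annotations):
    color = node_color_function(clf.tree_.impurity[node_id])
    annotation.get_bbox_patch().set_facecolor(color)
plt.show()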

Output:

[Figure: Changing Node Colors]
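As a variation on custom node colors, the sketch below maps each node's Gini impurity onto a continuous matplotlib colormap instead of fixed thresholds. It assumes the annotations returned by `plot_tree` follow the node numbering of `clf.tree_` (default tree growth, tree drawn in full); the colormap name, depth limit, and `random_state` are illustrative choices.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(12, 7))
annotations = plot_tree(
    clf,
    feature_names=list(iris.feature_names),
    class_names=list(iris.target_names),
    filled=True,
    ax=ax,
)

impurity = clf.tree_.impurity  # Gini impurity of every node
norm = mcolors.Normalize(vmin=0.0, vmax=impurity.max())
cmap = plt.get_cmap("YlOrRd")  # pale yellow for pure nodes, red for mixed ones

# Assumption: annotations[i] corresponds to node i of clf.tree_
for node_id, ann in enumerate(annotations):
    rgba = cmap(float(norm(impurity[node_id])))
    ann.get_bbox_patch().set_facecolor(rgba)
```

Call `plt.show()` to display the result. A continuous scale avoids having to choose thresholds, at the cost of no longer encoding the majority class in the color.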

Enhancing Readability with Contrasting Colors

When dealing with complex decision trees, it's crucial to ensure that the text remains readable against the node colors. One way to achieve this is to set the text color dynamically based on the background color. The function below calculates the luminance of the background color and sets the text color to white or black accordingly, so that the text contrasts with its box.

Python
# Enhancing Readability with Contrasting Text Colors
import matplotlib.colors as mcolors


def get_text_color(bg_color):
    # Weighted sum of the RGB channels (each in 0..1) approximates perceived brightness
    rgb = mcolors.to_rgb(bg_color)
    luminance = 0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]
    return 'white' if luminance < 0.5 else 'black'


# Plot the tree and apply text contrast
fig, ax = plt.subplots(figsize=(20, 10))
plot_tree(clf,
          feature_names=list(iris.feature_names),
          class_names=list(iris.target_names),
          filled=True,
          ax=ax)

# Adjust text colors based on background luminance
for text in ax.texts:
    bbox = text.get_bbox_patch()
    if bbox is not None:
        bg_color = bbox.get_facecolor()
        text.set_color(get_text_color(bg_color))
plt.show()

Output:

[Figure: Enhancing Readability with Contrasting Colors]
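To get a feel for the 0.5 threshold, the luminance formula can be evaluated on a few named colors without drawing a tree. This is a small sketch that restates the weighted sum from the text; the helper name `luminance` and the list of colors are illustrative choices.

```python
import matplotlib.colors as mcolors


def luminance(color):
    """Weighted sum of the RGB channels (each in 0..1), as used in the text."""
    r, g, b = mcolors.to_rgb(color)
    return 0.299 * r + 0.587 * g + 0.114 * b


for name in ["navy", "lightblue", "pink", "black", "white"]:
    lum = luminance(name)
    text_color = "white" if lum < 0.5 else "black"
    print(f"{name:>10}: luminance={lum:.3f} -> {text_color} text")
```

Dark fills such as navy fall below the threshold and get white text, while pale fills such as pink stay above it and keep black text.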

Creating a Color Legend

To make your visualization more informative, it's often helpful to include a color legend. While plot_tree doesn't provide this functionality directly, we can add it using matplotlib:

Python
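# Adding a Custom Color Legend
from matplotlib.patches import Patch

class_colors = ['lightblue', 'lightgreen', 'pink']  # one color per class

fig, ax = plt.subplots(figsize=(20, 10))
annotations = plot_tree(clf,
                        feature_names=list(iris.feature_names),
                        class_names=list(iris.target_names),
                        filled=True,
                        ax=ax)

# Recolor each node by its majority class so the legend matches the plot
# (relies on the annotations following the node order of clf.tree_)
for node_id, annotation in enumerate(annotations):
    majority_class = clf.tree_.value[node_id][0].argmax()
    annotation.get_bbox_patch().set_facecolor(class_colors[majority_class])

# Create a custom legend with one entry per class
legend_elements = [Patch(facecolor=color, edgecolor='black', label=name)
                   for color, name in zip(class_colors, iris.target_names)]

# Add the legend to the plot
ax.legend(handles=legend_elements, loc='lower right')
plt.show()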

Output:

[Figure: Creating a Color Legend]

Highlighting Important Features

In some cases, you might want to highlight nodes that use specific features. Here's how you can achieve this:

Python
from matplotlib.patches import Patch

important_features = ['petal length (cm)', 'petal width (cm)']  # Features to highlight
highlight_color = 'pink'  # Used for both the nodes and the legend entry

# Plot the decision tree
fig, ax = plt.subplots(figsize=(20, 10))
tree_plot = plot_tree(clf, feature_names=list(iris.feature_names),
                      class_names=list(iris.target_names), filled=True, ax=ax)

# Highlight important features by changing node colors
for i, node in enumerate(tree_plot):
    node_feature = clf.tree_.feature[i]  # Index of the feature node i splits on
    if node_feature != -2:  # -2 means the node is a leaf
        feature_name = iris.feature_names[node_feature]
        if feature_name in important_features:
            node.get_bbox_patch().set_facecolor(highlight_color)

# Add a legend entry that matches the highlight color
legend_elements = [Patch(facecolor=highlight_color, edgecolor='black', label='Important Feature')]
ax.legend(handles=legend_elements, loc='upper right')

plt.show()

Output:

[Figure: Highlighting Important Features]
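Before relying on the colors, it can help to list which nodes split on the features of interest. The sketch below reads `clf.tree_.feature` directly, where leaves are stored as -2; the `random_state` value and the variable names are illustrative assumptions.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

important_features = {"petal length (cm)", "petal width (cm)"}

# tree_.feature holds the feature index used at each node; leaves store -2
split_nodes = {
    node_id: iris.feature_names[feature_index]
    for node_id, feature_index in enumerate(clf.tree_.feature)
    if feature_index != -2
}
highlighted = [n for n, name in split_nodes.items() if name in important_features]

print(f"{len(split_nodes)} split nodes, {len(highlighted)} would be highlighted")
for node_id, name in split_nodes.items():
    marker = "*" if name in important_features else " "
    print(f"{marker} node {node_id}: splits on {name}")
```

If nearly every split node is marked, as can happen on Iris where the petal measurements dominate, highlighting adds little, and a narrower feature list may be more informative.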

Conclusion

Customizing colors in decision tree plots drawn with sklearn's plot_tree function can make them easier to interpret. Because plot_tree returns ordinary matplotlib annotations, you can recolor nodes with custom color functions, adjust the text for contrast, add a legend, and highlight the features you care about.

