Lesson 3.4: Taming the Tree - Pruning and Hyperparameter Tuning

We've identified the core problem: a single decision tree overfits by growing too complex. This lesson covers the practical solutions. We'll explore the two main strategies for controlling a tree's complexity: pre-pruning (stopping it early) via hyperparameter tuning, and post-pruning (cutting it back) via Cost-Complexity Pruning.

Part 1: The Two Philosophies of Control

To prevent our "over-eager intern" from building a ridiculously complex model, we need to give them some rules. There are two ways to do this:

1. Pre-Pruning (Early Stopping)

This is the most common and intuitive approach. We set rules that stop the tree from growing before it becomes too complex.

Analogy: We tell the intern, "You can't ask more than 5 questions," or "Don't create a rule for a group that has fewer than 10 customers."

2. Post-Pruning (Cutting Back)

This approach lets the tree grow to its full, overfit size, and then systematically cuts back branches that don't add much predictive value.

Analogy: We let the intern create their massive, convoluted rulebook. Then we go through it with a red pen and cross out all the silly, non-generalizable rules.
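
To make the two philosophies concrete, here is a minimal sketch on a synthetic dataset (an assumption for illustration; the lesson does not prescribe a dataset). It fits an unrestricted tree, a pre-pruned tree, and a post-pruned tree, then compares their sizes and accuracies. The values `max_depth=5`, `min_samples_leaf=10`, and `ccp_alpha=0.005` are arbitrary illustrative choices, not recommendations.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic, noisy data (hypothetical stand-in for a real dataset)
X, y = make_classification(
    n_samples=1000, n_features=20, n_informative=5, flip_y=0.1, random_state=0
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0
)

trees = {
    # No limits: the tree keeps splitting until every leaf is pure
    "unpruned": DecisionTreeClassifier(random_state=0),
    # Pre-pruning: growth is stopped early by the rules we set
    "pre-pruned": DecisionTreeClassifier(
        max_depth=5, min_samples_leaf=10, random_state=0
    ),
    # Post-pruning: grow fully, then cut back weak branches
    "post-pruned": DecisionTreeClassifier(ccp_alpha=0.005, random_state=0),
}

results = {}
for name, tree in trees.items():
    tree.fit(X_train, y_train)
    results[name] = {
        "leaves": tree.get_n_leaves(),
        "train_acc": tree.score(X_train, y_train),
        "test_acc": tree.score(X_test, y_test),
    }
    print(name, results[name])
```

On data like this, the unpruned tree typically memorizes the training set while both controlled trees end up with far fewer leaves. The exact accuracies depend on the data and the settings chosen.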

Part 2: Pre-Pruning via Hyperparameter Tuning

This is the most practical method used in machine learning libraries like Scikit-learn. We control the tree's growth by setting its **hyperparameters** before training. The goal is to find the combination of hyperparameters that performs best on our **validation set** (as discussed in Lesson 1.2).

The Key Hyperparameters for Decision Trees

  • max_depth: The most important one. It specifies the maximum number of levels the tree can have. A small max_depth (e.g., 3-5) creates a simple, interpretable model with low variance but potentially high bias.
  • min_samples_split: The minimum number of samples a node must have before it can be split. Setting this to a higher value prevents the model from learning from tiny, insignificant groups of data.
  • min_samples_leaf: The minimum number of samples a leaf node must have. This ensures that every final prediction is based on a reasonably large group of data points.
  • max_features: The number of features to consider when looking for the best split. By default, the tree considers all of them, but restricting this number can help reduce variance. Evaluating only a random subset of features at each split is one of the key ideas behind Random Forests.
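
As a rough illustration of the hyperparameters listed above, the following sketch fits one tree per setting on a synthetic dataset and records how each setting changes the shape of the tree. The dataset and the specific values (3, 50, 20, `"sqrt"`) are assumptions chosen only for demonstration.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic data (hypothetical stand-in for a real dataset)
X, y = make_classification(
    n_samples=500, n_features=10, n_informative=4, random_state=1
)

settings = {
    "defaults": {},
    "max_depth=3": {"max_depth": 3},
    "min_samples_split=50": {"min_samples_split": 50},
    "min_samples_leaf=20": {"min_samples_leaf": 20},
    "max_features=sqrt": {"max_features": "sqrt"},
}

summary = {}
for label, params in settings.items():
    tree = DecisionTreeClassifier(random_state=1, **params).fit(X, y)
    # In scikit-learn's tree arrays, a leaf has no left child (marked -1)
    is_leaf = tree.tree_.children_left == -1
    node_sizes = tree.tree_.n_node_samples
    summary[label] = {
        "depth": tree.get_depth(),
        "leaves": tree.get_n_leaves(),
        "smallest_leaf": int(node_sizes[is_leaf].min()),
        "smallest_split_node": int(node_sizes[~is_leaf].min()),
    }
    print(label, summary[label])
```

Reading the output row by row shows each rule at work: `max_depth` caps the depth, `min_samples_split` guarantees that every node that was split held at least that many samples, and `min_samples_leaf` guarantees a minimum leaf size.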

Finding the Best Hyperparameters with GridSearchCV

We don't guess these values. Instead, we use a systematic search such as **Grid Search with Cross-Validation**, which scores every combination in a predefined grid and keeps the one with the best cross-validated performance.

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV

# 1. Define the parameter grid
param_grid = {
    'max_depth': [3, 5, 7, 10, None],
    'min_samples_split': [2, 10, 20],
    'min_samples_leaf': [1, 5, 10]
}

# 2. Create the model and the GridSearchCV object
tree = DecisionTreeClassifier(random_state=42)
# cv=5 means 5-fold cross-validation
grid_search = GridSearchCV(estimator=tree, param_grid=param_grid, cv=5, n_jobs=-1, scoring='accuracy')

# 3. Fit it to your training data
# X_train, y_train come from your train_test_split
grid_search.fit(X_train, y_train)

# 4. Get the best model
print(f"Best parameters found: {grid_search.best_params_}")
best_tree = grid_search.best_estimator_

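# 5. Evaluate the tuned tree once on the held-out test set
# X_test, y_test come from the same train_test_split
print(f"Test accuracy: {best_tree.score(X_test, y_test):.3f}")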
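
The grid-search snippet above assumes that `X_train` and `y_train` already exist. The following self-contained sketch fills in that gap with a synthetic dataset (an assumption for illustration) and a deliberately smaller grid, then inspects `cv_results_` to see how the top candidates compare before scoring the winner on the test set.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic data (hypothetical stand-in for a real dataset)
X, y = make_classification(
    n_samples=600, n_features=12, n_informative=5, random_state=42
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)

# A reduced grid so the example runs quickly: 3 x 3 = 9 candidates
param_grid = {"max_depth": [3, 5, None], "min_samples_leaf": [1, 5, 10]}
grid_search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,
    scoring="accuracy",
)
grid_search.fit(X_train, y_train)

# Rank every candidate by its mean cross-validated accuracy
results = grid_search.cv_results_
ranked = sorted(
    zip(results["rank_test_score"], results["mean_test_score"], results["params"]),
    key=lambda row: row[0],
)
for rank, mean_score, params in ranked[:3]:
    print(f"rank {rank}: mean CV accuracy {mean_score:.3f} with {params}")

# The test set is touched only once, after the search is finished
best_tree = grid_search.best_estimator_
test_accuracy = best_tree.score(X_test, y_test)
print(f"Test accuracy of the best tree: {test_accuracy:.3f}")
```

Looking at the top few candidates rather than only `best_params_` is a useful habit: if several settings score almost identically, the simpler tree is usually the safer choice.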

Part 3: Post-Pruning with Cost-Complexity Pruning

While less common in day-to-day ML practice than pre-pruning, this is the classic, theoretically elegant method.

Minimal Cost-Complexity Pruning (CCP)

This method introduces a new hyperparameter, $\alpha$ (alpha), which controls a tradeoff between model fit and complexity.

For a tree $T$, we define a cost-complexity measure:

$$R_\alpha(T) = \text{Impurity}(T) + \alpha \cdot |\text{Leaves}(T)|$$

  • $\text{Impurity}(T)$ is the total impurity of all the leaf nodes in the tree.
  • $|\text{Leaves}(T)|$ is the number of leaf nodes, which is our measure of the tree's complexity.
  • $\alpha$ is the complexity parameter. A higher $\alpha$ puts a heavier penalty on having more leaves, forcing the algorithm to choose simpler trees.
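
A small worked example may help show how the penalty shifts the preference between trees. The impurity values and leaf counts below are made-up numbers for two hypothetical trees, chosen only to illustrate the arithmetic of the cost-complexity measure defined above.

```python
def cost_complexity(impurity, n_leaves, alpha):
    """R_alpha(T) = Impurity(T) + alpha * |Leaves(T)|."""
    return impurity + alpha * n_leaves

# Hypothetical trees: the big one fits better but has four times the leaves
big_tree = {"impurity": 0.10, "n_leaves": 20}
small_tree = {"impurity": 0.18, "n_leaves": 5}

scores = {}
for alpha in (0.001, 0.01):
    r_big = cost_complexity(alpha=alpha, **big_tree)
    r_small = cost_complexity(alpha=alpha, **small_tree)
    winner = "big" if r_big < r_small else "small"
    scores[alpha] = (r_big, r_small, winner)
    print(f"alpha={alpha}: R(big)={r_big:.3f}, R(small)={r_small:.3f} -> {winner}")

# The alpha at which both trees cost the same:
# extra impurity of the small tree divided by the leaves it saves
break_even_alpha = (small_tree["impurity"] - big_tree["impurity"]) / (
    big_tree["n_leaves"] - small_tree["n_leaves"]
)
print(f"break-even alpha: {break_even_alpha:.5f}")
```

With a small penalty the larger tree has the lower cost; once $\alpha$ rises past the break-even point, the smaller tree becomes the better choice.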

The algorithm finds a sequence of optimal subtrees for a range of $\alpha$ values. We then use cross-validation to find the single $\alpha$ (and its corresponding subtree) that gives the best out-of-sample performance.
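
In Scikit-learn, this procedure can be sketched with `cost_complexity_pruning_path`, which returns the candidate alpha values, and the `ccp_alpha` parameter of `DecisionTreeClassifier`, which applies the pruning. The example below is a minimal sketch on a synthetic dataset (an assumption for illustration). Taking the candidate alphas from a path computed on the full training set and then cross-validating over them is a common simplification rather than the only way to do it.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic, noisy data (hypothetical stand-in for a real dataset)
X, y = make_classification(
    n_samples=600, n_features=12, n_informative=5, flip_y=0.1, random_state=0
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0
)

# 1. Compute the sequence of alphas at which the full tree loses a branch
full_tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
path = full_tree.cost_complexity_pruning_path(X_train, y_train)

# Drop the last alpha (it prunes the tree down to the root alone) and
# clip tiny negative values that can appear from floating-point error
candidate_alphas = np.unique(np.clip(path.ccp_alphas[:-1], 0.0, None))

# 2. Use cross-validation to pick the alpha with the best accuracy
grid = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    param_grid={"ccp_alpha": candidate_alphas},
    cv=5,
    scoring="accuracy",
)
grid.fit(X_train, y_train)

# 3. Compare the pruned tree with the fully grown one
pruned_tree = grid.best_estimator_
print(f"Chosen alpha: {grid.best_params_['ccp_alpha']:.5f}")
print(f"Leaves: full={full_tree.get_n_leaves()}, pruned={pruned_tree.get_n_leaves()}")
print(f"Test accuracy: full={full_tree.score(X_test, y_test):.3f}, "
      f"pruned={pruned_tree.score(X_test, y_test):.3f}")
```

Because `ccp_alpha` is an ordinary constructor parameter, it can also be added to the same `param_grid` as the pre-pruning hyperparameters if you want to tune both kinds of control together.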

What's Next? The Power of the Crowd

We now have a complete toolkit for building and controlling a single decision tree. We can tune its hyperparameters to find the optimal balance between bias and variance, resulting in a model that is both interpretable and reasonably predictive.

However, the real power of decision trees is not in using them one at a time. It's in combining them by the hundreds or thousands.

In **Module 4: The Power of the Crowd: Ensemble Methods**, we will take the single decision tree and use it as a building block to construct two of the most powerful and widely used models for tabular data: **Random Forest** and **Gradient Boosting Machines** (with XGBoost as a popular implementation).