XGBoost as an Interaction Explorer

This post explores one approach to transforming tree-based models, in this case XGBoost, into linear models. In this framework, XGBoost can be seen as a method for exploring interactions between features and regions of data.

Ben Ewing https://improperprior.com/


Tree-based models, like XGBoost and Random Forests, can be incredibly powerful tools for modeling structured data; however, they are not typically considered interpretable. While you can inspect the trees, look at feature importance measures and partial-dependence plots, or use explanation methods like LIME and SHAP, these approaches are not perfect (Slack et al. 2020) and are often not sufficient for high-stakes decisions (Rudin 2019) or regulatorily complex environments. In this post, I will show one approach for using XGBoost to generate interpretable embeddings that can improve linear model performance.

This post presumes some knowledge about decision trees, but not necessarily any specific tree-growing method.

The Main Idea

A single decision tree is just a set of decision points ending at terminal nodes called leaves. Within a tree, each data point is sorted into a specific leaf; each leaf can be thought of as a representation of a partition of the data. This representation naturally includes interactions between features, which makes it especially powerful. While a single decision tree can be inspected and interpreted directly (just look at each decision), models like XGBoost and Random Forests rely on fitting many trees.

The general idea presented here is that we can fit XGBoost with relatively shallow trees. Rather than getting predictions from each of the trees, we get the leaf that each data point is represented by. At this point we will have quite a few leaves; exactly how many will depend on the number of trees, their depth, and the pruning parameters, but it will likely be too many to be considered interpretable. It can be pretty easy to end up with more leaves than observations! To reduce the number of leaves we can use a LASSO regression to select the most informative ones.

We fit with shallow trees to aid with interpretability, but exactly what constitutes shallow will be subjective (what do you feel like you can interpret correctly?) and may be domain-specific. To prevent overfitting, it’s best to train the XGBoost and LASSO models on separate datasets.
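As a minimal, self-contained sketch of the idea (toy data and the classic {xgboost} interface, not anything from this post):

```r
library(xgboost)

# Toy data: two features whose interaction (not either feature alone)
# determines the outcome
set.seed(1)
n <- 500
X <- matrix(runif(n * 2), ncol = 2, dimnames = list(NULL, c('x1', 'x2')))
y <- as.numeric(X[, 1] > 0.5 & X[, 2] > 0.5)

# Shallow trees: depth 2 is just enough to capture a two-way interaction
bst <- xgboost(data = X, label = y, nrounds = 10, max_depth = 2,
               objective = 'binary:logistic', verbose = 0)

# Instead of predictions, ask for the leaf each observation lands in:
# one integer per (observation, tree) pair - these are the embeddings
leaves <- predict(bst, X, predleaf = TRUE)
dim(leaves)  # 500 observations x 10 trees
```

One-hot encoding `leaves` and handing the indicators to a LASSO is exactly the workflow developed in the rest of this post.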

This is not an original idea by any means! I first learned of this approach when reading about the Loft Data Science Team’s XGBoost Survival Embeddings (Loft 2021). The Scikit-learn documentation also discusses a similar approach, and shows how it can be incorporated in a pipeline (Head 2021).

An Example

For the remainder of this post I’ll demonstrate the application of this approach. However, I want to highlight that this approach may not always be the best way to discover interesting embeddings (e.g. as compared to domain knowledge), or may result in no linear model improvement at all.

I’ll be using R for this blog post, but the approach described here is language agnostic. I’ve also done some model tuning that is not included in this post for the sake of focus and brevity.

# Data manipulation
library(tidyverse)
library(glue)
# I/O
# Modeling
library(tidymodels)
# Parallel things
# Visualization
library(ggthemes)

# Settings/misc things
mc_metrics <- metric_set(roc_auc, accuracy)

Austin Housing Data

I’ll be using the Austin housing data from the Sliced Semifinals dataset. This dataset contains characteristics and binned sale prices of 10,000 homes in Austin, TX. The outcome of interest is the sale price, but to make life simpler, I’ll focus on predicting whether or not the price of a given home is above $450,000. Though they are likely very useful, I’m going to ignore the city and description columns, and collapse the homeType column.

Load the data.

# Data is read in elsewhere, because file paths :)
# Encode the outcome and drop unneeded vars
df <- df %>% 
  mutate(y = factor(priceRange %in% c("450000-650000", "650000+"),
                    c(T, F), c('above_450000', 'below_450000')),
         homeType = fct_other(homeType, keep = c('Single Family'))) %>%
  select(-c(city, description, priceRange))

# Split into sets - for XGBoost, regression, and testing
df_split <- initial_split(df, prop = 0.5, strata = y)
df_train_split <- initial_split(training(df_split), strata = y)
df_train_xgb <- training(df_train_split)
df_train_lr <- testing(df_train_split)
df_test <- testing(df_split)

Baseline Linear Model

Let’s specify a baseline linear model. Note that because this model is linear and I have no reason to believe that latitude and longitude are linearly related to the price range, I am going to omit those variables.

lr_recipe <- recipe(y ~ hasSpa + garageSpaces + yearBuilt +
                      numOfPatioAndPorchFeatures + lotSizeSqFt +
                      avgSchoolRating + MedianStudentsPerTeacher + 
                      numOfBathrooms + numOfBedrooms + homeType,
                    data = df_train_lr)

lr_spec <- logistic_reg(mode = 'classification', engine = 'glm')

lr_wf <- workflow() %>% 
  add_recipe(lr_recipe) %>% 
  add_model(lr_spec)

Now we can fit and get performance estimates.

df_train_lr_folds <- vfold_cv(df_train_lr, v = 10, strata = y)

fit_resamples(lr_wf, resamples = df_train_lr_folds) %>%
  collect_metrics()

.metric   .estimator  mean       n   std_err    .config
accuracy  binary      0.7761685  10  0.0048137  Preprocessor1_Model1
roc_auc   binary      0.8380181  10  0.0059774  Preprocessor1_Model1

This is not terrible, but there’s plenty of room for improvement.

Fitting XGBoost and Choosing Leaves

Where is there room for improvement? Well, we omitted a pretty big feature: the location of the home. We did this because there’s no reason to believe, a priori, that home value should increase with latitude or longitude. However, there is reason to believe that latitude and longitude do have some relationship to home price; in other words, neighborhoods exist. Luckily for us, XGBoost is pretty good at learning non-linear relationships. Let’s fit XGBoost using just longitude and latitude, and see if we can effectively use those leaf-embeddings in a downstream model.

Specify the model. Note that tree depth was set for interpretability, and the number of trees is set with my laptop’s performance in mind; the other parameters are the result of some light tuning.

lat_lon_recipe <- recipe(y ~ latitude + longitude,
                         data = df_train_xgb)

xgb_spec <- boost_tree(mode = 'classification', engine = 'xgboost',
                       trees = 100, tree_depth = 2, mtry = 2,
                       min_n = 5, sample_size = 0.412, learn_rate = 0.0997)
xgb_wf <- workflow() %>% 
  add_recipe(lat_lon_recipe) %>% 
  add_model(xgb_spec)

Fit on resamples and get performance estimates.

df_train_xgb_folds <- vfold_cv(df_train_xgb, v = 10, strata = y)

fit_resamples(xgb_wf, resamples = df_train_xgb_folds) %>%
  collect_metrics()

.metric   .estimator  mean       n   std_err    .config
accuracy  binary      0.7927581  10  0.0041516  Preprocessor1_Model1
roc_auc   binary      0.8749483  10  0.0038339  Preprocessor1_Model1

Performance is a bit better than the logistic regression’s, which is not bad at all considering we’re only using two features! Including other features would certainly improve performance, and we could do that while still using this interpretable approach, but for this post I’m going to stick to just latitude and longitude.

Let’s use a LASSO regression to pull out a few of the more important features. Identifying a good penalty for the LASSO turns out to be quite tricky: how many leaves do you want? How many are interpretable? How many are needed to maintain performance? It may be necessary to rely on a heuristic here, for example just using the top N leaves.

The first step, getting predictions from the leaves, looks a little messy, but most of this code is just one-hot encoding the leaves. The {recipes} package can also be used to one-hot encode data in far fewer lines of code.

xgb_fit <- fit(xgb_wf, df_train_xgb)

add_leaves <- function(xgb_fit, df) {
  n_trees <- xgb_fit %>% extract_fit_engine() %>% .$niter
  predict(xgb_fit, df, type = 'raw', opts = list(predleaf = T)) %>% 
    as_tibble(.name_repair = ~ paste0('tree_', 0:(n_trees - 1))) %>% 
    mutate(uid = df$uid, .before = tree_0) %>% 
    pivot_longer(starts_with('tree')) %>%
    transmute(uid, tree_leaf = paste0(name, '_X', value), value = 1) %>%
    pivot_wider(names_from = tree_leaf, values_from = value, 
                values_fill = 0) %>% 
    inner_join(df, by = 'uid')
}

df_split$data <- add_leaves(xgb_fit, df_split$data)
df_train_lr <- add_leaves(xgb_fit, df_train_lr)
df_test <- add_leaves(xgb_fit, df_test)
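For reference, here is a hedged sketch of the {recipes} route mentioned above, on a made-up stand-in for the real leaf matrix (the `df_leaves` data frame and its `tree_*` columns are hypothetical):

```r
library(tidymodels)

# Hypothetical stand-in: two trees' leaf indices (as factors) plus the outcome
df_leaves <- tibble(
  y = factor(c('above', 'below', 'above', 'below')),
  tree_0 = factor(c(3, 4, 3, 5)),
  tree_1 = factor(c(1, 1, 2, 2))
)

# step_dummy() with one_hot = TRUE expands each tree's leaf factor into 0/1
# indicator columns, named like tree_0_X3 (mirroring add_leaves() above)
leaf_onehot <- recipe(y ~ ., data = df_leaves) %>%
  step_dummy(all_nominal_predictors(), one_hot = TRUE) %>%
  prep() %>%
  bake(new_data = NULL)
```

This trades the explicit pivoting for a preprocessing step that can live inside a workflow.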

Now we can fit a linear model, either on all covariates, or just the leaves. The decision will again depend on the specific goal and domain. In this example, I’m most interested in discovering which combinations of latitude and longitude are most informative, so I will include just those.

df_train_lr_folds <- vfold_cv(df_train_lr, v = 10, strata = y)

leaf_rec <- recipe(df_train_lr) %>% 
  update_role(y, new_role = 'outcome') %>% 
  update_role(starts_with('tree_'), new_role = 'predictor')

leaf_spec <- logistic_reg(mode = 'classification', engine = 'glmnet',
                          mixture = 1, penalty = 0.1)

leaf_wf <- workflow() %>% 
  add_recipe(leaf_rec) %>% 
  add_model(leaf_spec)

Let’s fit this LASSO on some cross-folds and get a performance estimate. The penalty here is relatively harsh, so we should expect the performance to be slightly worse than the XGBoost model. This is fine, though; remember, we just want to pull out the top embeddings to use in our final linear model.

leaf_cv <- fit_resamples(leaf_wf, resamples = df_train_lr_folds)

leaf_cv %>% 
  collect_metrics()

.metric   .estimator  mean       n   std_err    .config
accuracy  binary      0.7442199  10  0.0155847  Preprocessor1_Model1
roc_auc   binary      0.8409025  10  0.0148693  Preprocessor1_Model1

As expected, a bit worse, but still not terrible considering we’re using two base features and regularizing out many of the XGBoost model’s trees. Let’s figure out which trees represent the best embeddings. Rather than using a single fit for this, let’s look at which leaves show up as important across cross-validation rounds. Note that the scoring mechanism used here is purely speculative: I hope it will result in a more robust selection of embeddings, but this could be an interesting place to explore other options.

leaf_importance <- map_dfr(df_train_lr_folds$splits, ~ {
  fit(leaf_wf, analysis(.x)) %>% 
    extract_fit_engine() %>% 
    vip::vi() %>% 
    filter(Importance >= 1)
}) %>% 
  group_by(Variable) %>% 
  summarise(mean = mean(Importance), 
            sd = sd(Importance),
            n = n(),
            score = n*((mean)/sd)) %>% 
  ungroup() %>% 
  filter(n >= 5) %>% 
  arrange(desc(score), Variable)

head(leaf_importance)

Variable     mean       sd         n   score
tree_58_X3    2.519602  0.2440208  10  103.25358
tree_95_X4    1.722802  0.2680110  10   64.28102
tree_97_X4    6.440888  1.0307758  10   62.48584
tree_48_X3   11.820779  1.8933194  10   62.43415
tree_57_X6    4.634435  0.7443771  10   62.25924
tree_89_X5    3.656289  0.7083346  10   51.61811

With the key leaves selected, we can move on to a final linear model that uses both these leaves and the other features available.

Final Linear Model

Specify and fit the final model.

best_leaves <- leaf_importance %>% 
  top_n(n = 10, wt = score) %>% 
  pull(Variable)

flr_recipe <- recipe(df_train_lr) %>% 
  update_role(y, new_role = 'outcome') %>%
  update_role(hasSpa, garageSpaces, yearBuilt,
              numOfPatioAndPorchFeatures, lotSizeSqFt,
              avgSchoolRating, MedianStudentsPerTeacher,
              numOfBathrooms, numOfBedrooms,
              new_role = 'predictor') %>% 
  update_role(all_of(best_leaves), new_role = 'predictor')

flr_wf <- workflow() %>% 
  add_recipe(flr_recipe) %>% 
  add_model(lr_spec)

Let’s get final performance estimates and see if all of this was worth it.

fit_resamples(flr_wf, resamples = df_train_lr_folds) %>%
  collect_metrics()

.metric   .estimator  mean       n   std_err    .config
accuracy  binary      0.8009759  10  0.0086919  Preprocessor1_Model1
roc_auc   binary      0.8768469  10  0.0126593  Preprocessor1_Model1

The result is a pretty decent performance bump! So far we’ve been totally ignoring test performance; let’s take a look at that as well, across all of the models.

bind_rows(
  last_fit(lr_wf, df_split, metrics = mc_metrics) %>% 
    collect_metrics() %>% 
    transmute(model = 'logistic_regression', .metric, .estimate),
  last_fit(xgb_wf, df_split, metrics = mc_metrics) %>% 
    collect_metrics() %>% 
    transmute(model = 'xgboost_lat_lon', .metric, .estimate),
  last_fit(leaf_wf, df_split, metrics = mc_metrics) %>% 
    collect_metrics() %>% 
    transmute(model = 'lasso_leaf', .metric, .estimate),  
  last_fit(flr_wf, df_split, metrics = mc_metrics) %>% 
    collect_metrics() %>% 
    transmute(model = 'final_logistic_regression', .metric, .estimate)
) %>% 
  pivot_wider(names_from = .metric, values_from = .estimate)

model                      accuracy  roc_auc
logistic_regression        0.7824    0.8519190
xgboost_lat_lon            0.8080    0.8897721
lasso_leaf                 0.7548    0.8500604
final_logistic_regression  0.8166    0.8904807

It’s not amazing, but that’s a nice performance bump. I’m positive a bit of fine-tuning (and better selection of leaves) could result in improved performance.

Interpreting the Leaves

So far, we’ve acted like just having the embeddings in a linear model means that this model is interpretable. But what are the leaves actually representing? Fortunately, we fit XGBoost with fairly shallow trees, so it should be pretty easy to actually inspect them!

Here’s a function that can navigate a tree and summarize each leaf.

summarize_leaves <- function(tree, id, sofar) {
  row <- tree %>% filter(ID == id)
  if (row$Feature == 'Leaf') {
    ret <- list()
    ret[[row$ID]] <- sofar
    return(ret)
  } else {
    left_sofar <- c(sofar, glue("{row$Feature}<{row$Split}"))
    left_tree <- summarize_leaves(tree, row$Yes, left_sofar)
    right_sofar <- c(sofar, glue("{row$Feature}>={row$Split}"))
    right_tree <- summarize_leaves(tree, row$No, right_sofar)
    return(c(left_tree, right_tree))
  }
}
Let’s get the XGBoost tree and then extract the leaves of interest.

xgb_tree <- xgboost::xgb.model.dt.tree(model = xgb_fit %>% extract_fit_engine())

trees <- best_leaves %>% 
  str_extract('^tree_[0-9]*') %>% 
  str_remove('tree_')
leaves <- best_leaves %>% 
  str_extract('X[0-9]*') %>%
  str_remove('X')

best_embeddings <- map_dfr(1:length(trees), function(index) {
  toi <- trees[index]
  loi <- leaves[index]
  leaf_label <- glue('{toi}-{loi}')
  tibble(
    tree = toi,
    leaf = loi,
    label = leaf_label,
    emb = xgb_tree %>% 
      filter(Tree == toi) %>% 
      summarize_leaves(glue('{toi}-0'), c()) %>% 
      pluck(as.character(leaf_label))
  )
})
head(best_embeddings)

tree  leaf  label  emb
58    3     58-3   latitude<30.4451027
58    3     58-3   longitude<-97.6950073
95    4     95-4   latitude<30.4164162
95    4     95-4   latitude>=30.3882008
97    4     97-4   latitude<30.4989376
97    4     97-4   latitude>=30.4512882

Let’s take this just one step further and plot these on top of our data! We’ll need to do a little processing of the embeddings to plot them in a palatable way. Specifically, the max/min of the lat/lon (depending on the direction of the cutoff) is what ultimately determines how the embedding works. As such, these are what I’ll focus on plotting.

best_embeddings <- best_embeddings %>% 
  mutate(direction = ifelse(str_detect(emb, ">="), ">=", "<")) %>% 
  separate(emb, c('variable', 'cutpoint'), "(\\>\\=|\\<)", convert = T) %>% 
  distinct() %>% 
  group_by(label, variable, direction) %>% 
  filter(case_when(
    direction == ">=" ~ cutpoint == max(cutpoint),
    direction == "<" ~ cutpoint == min(cutpoint)
  )) %>% 
  ungroup()

ggplot(data = best_embeddings) +
  geom_point(data = df, aes(latitude, longitude, colour = y),
             alpha = 0.1) +
  geom_vline(data = . %>% filter(variable == 'latitude'),
             aes(xintercept = cutpoint, colour = direction),
             size = 1) +
  geom_hline(data = . %>% filter(variable == 'longitude'),
             aes(yintercept = cutpoint, colour = direction),
             size = 1) +
  facet_wrap(vars(label), ncol = 5) +
  scale_color_few() +
  labs(colour = '') +
  theme_minimal() +
  theme(legend.position = 'bottom',
        axis.text = element_blank())

From this we can see the regions that our embeddings are identifying. They’re clearly not perfect, but they definitely do identify some regional cutoffs. How many of these embeddings apply to each home?

df_train_lr %>% 
  select(uid, all_of(best_leaves)) %>% 
  pivot_longer(starts_with('tree')) %>% 
  group_by(uid) %>% 
  summarise(number_of_active_embeddings = sum(value)) %>% 
  ungroup() %>% 
  count(number_of_active_embeddings)
number_of_active_embeddings    n
0                            117
1                            428
2                            223
3                            210
4                            216
5                             57

Given that I’m using very simple trees, this makes a lot of sense to me. I’d argue that these are still interpretable. Checking the number of leaves applying to a given observation may also be a good way to check for overfitting.

This concludes my example. I think a lot more optimization could be done to improve this model, but I’m quite happy with it as a first pass.

Related Ideas

This post is already long, at least by my standards, but I just want to hit on a few other very neat things this approach allows.

Missing Data

XGBoost can still be used with missing data: the algorithm simply lumps all observations with a missing value for the selected variable into one side of the split. As such, this approach can be used to create embeddings of columns with missing data; the resulting embeddings can then be used in models that do not typically support missing values.
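As a quick sketch of what this looks like with the {xgboost} package directly (toy data, not from this post):

```r
library(xgboost)

set.seed(1)
n <- 200
X <- matrix(runif(n * 2), ncol = 2, dimnames = list(NULL, c('x1', 'x2')))
y <- as.numeric(X[, 1] > 0.5)

# Knock out a fifth of the values; xgboost treats NA as 'missing' natively
X[sample(length(X), length(X) / 5)] <- NA

bst <- xgboost(data = X, label = y, nrounds = 5, max_depth = 2,
               objective = 'binary:logistic', verbose = 0)

# The leaf embeddings are complete even though the inputs were not, so a
# downstream linear model never sees the missingness
leaves <- predict(bst, X, predleaf = TRUE)
anyNA(leaves)  # FALSE
```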

Other Tree Based Models

There’s nothing special about XGBoost; any tree-based model could be used in its place. I suspect this could be advantageous in some situations.


End-to-End Tuning

Rather than tuning the XGBoost hyperparameters and then using the leaves in a subsequent model, we could bundle all of the steps into a single hyperparameter tuning loop. This would allow us to tune XGBoost specifically to create good embeddings. However, I’m not sure this would result in practically better performance.


Closing Thoughts

While not perfect, the use of a tree-based model to create features that are interpretable and usable in linear models is something I find really exciting, specifically because this approach can combine the best-in-class performance of decision forests with the interpretability of linear models that is often required in regulated environments. This is doubtless an area I will continue to explore.

Please feel free to reach out with any questions or corrections!

Head, Tim. 2021. “Feature Transformations with Ensembles of Trees.” https://scikit-learn.org/stable/auto_examples/ensemble/plot_feature_transformation.html.
Loft. 2021. “Xgbse: XGBoost Survival Embeddings.” https://loft-br.github.io/xgboost-survival-embeddings/index.html.
Rudin, Cynthia. 2019. “Stop Explaining Black Box Machine Learning Models for High Stakes Decisions and Use Interpretable Models Instead.” https://arxiv.org/abs/1811.10154.
Slack, Dylan, Sophie Hilgard, Emily Jia, Sameer Singh, and Himabindu Lakkaraju. 2020. “Fooling LIME and SHAP: Adversarial Attacks on Post Hoc Explanation Methods.” https://arxiv.org/abs/1911.02508.



For attribution, please cite this work as

Ewing (2022, Jan. 31). Improper Prior | Ben Ewing: XGBoost as an Interaction Explorer. Retrieved from https://improperprior.com/posts/2022-01-31-xgboost-as-a-feature-generator/

BibTeX citation

@misc{ewing2022xgboost,
  author = {Ewing, Ben},
  title = {Improper Prior | Ben Ewing: XGBoost as an Interaction Explorer},
  url = {https://improperprior.com/posts/2022-01-31-xgboost-as-a-feature-generator/},
  year = {2022}
}