Multivariate Adaptive Regression Splines in a Nutshell
Introduction
This post introduces multivariate adaptive regression splines (MARS). It explains the algorithm in a regression context [1] and assumes some background knowledge of stepwise linear regression.
[1] MARS can also incorporate logistic regression to predict probabilities in a classification context.
The Building Blocks
Like standard linear regression, MARS uses the ordinary least squares (OLS) method to estimate the coefficient of each term. However, instead of an original predictor, each term in a MARS model is a basis function derived from original predictors. A basis function takes one of the following forms:
- a constant 1, which represents the intercept;
- an original predictor from the data set;
- a hinge function (see below) derived from a predictor;
- a product of two or more hinge functions derived from different predictors, which captures the interaction between/among the predictors.
MARS handles categorical predictors the same way standard linear regression does. For continuous predictors, however, it creates a reflected pair of hinge functions for every combination of a predictor and one of its observed values, in the form
\[ \{\max(0, x - c), \max(0, c - x)\} \quad (1) \]
where \(x\) denotes a continuous predictor and \(c\) denotes an observed value of that predictor, referred to as a knot. Hinge functions are piecewise linear, and each one models an isolated portion of the original data. For example, \(\max(0, x - c)\), the mirror image of \(\max(0, c - x)\) about \(x = c\), is a linear function of \(x\) when \(x > c\) and stays at zero otherwise. If there are \(p\) continuous predictors and each has \(n\) distinct observed values, there are \(np\) such pairs of hinge functions. The model-training process iteratively selects some of them and adds them to the model (see the next section).
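The reflected pair in (1) can be sketched in a few lines of base R. The helper names `hinge_pos` and `hinge_neg` are made up for this illustration and are not part of any package.

```r
# Reflected pair of hinge functions at a given knot (illustrative helper names).
hinge_pos <- function(x, knot) pmax(0, x - knot)  # max(0, x - c)
hinge_neg <- function(x, knot) pmax(0, knot - x)  # max(0, c - x)

x <- c(-2, -1, 0, 1, 2, 3)
knot <- 1

# Each function is zero on one side of the knot and linear on the other.
hinge_pair <- cbind(x = x,
                    pos = hinge_pos(x, knot),   # 0 0 0 0 1 2
                    neg = hinge_neg(x, knot))   # 3 2 1 0 0 0
print(hinge_pair)
```

At most one member of the pair is non-zero for any given value of x, which is why each hinge function only affects the fit on its own side of the knot.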
Let \(h_m(X)\) (\(m \in \{0, 1, ..., M\}\)) denote the basis functions [2]; then a MARS model can be written as
\[ f(X) = \sum_{m = 0}^{M} \beta_m h_m(X) \]
where \(h_0(X) = 1\) is the only constant basis function.
[2] Since there may be multiple predictors in a basis function, \(h_m(X)\) should be considered as a function over the entire predictor space.
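To make the formula concrete, here is a minimal sketch in base R that fits a model of this form by OLS on simulated data. It assumes the knot is already known (here 0); in MARS the knot would be found by the training process. The data and variable names are made up for the illustration.

```r
set.seed(42)
n <- 200
x <- runif(n, min = -2, max = 2)
y <- abs(x) + rnorm(n, sd = 0.1)  # the true pattern has a kink at x = 0

knot <- 0  # assumed known for this sketch; MARS would search for it
basis <- data.frame(
  y  = y,
  h1 = pmax(0, x - knot),  # max(0, x - c)
  h2 = pmax(0, knot - x)   # max(0, c - x)
)

# h_0(X) = 1 is the intercept column that lm() adds automatically.
fit <- lm(y ~ h1 + h2, data = basis)
beta <- coef(fit)
print(round(beta, 2))
```

Since |x| equals max(0, x) + max(0, -x), the estimated coefficients of both hinge functions should be close to 1 and the intercept close to 0.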
The Training Process
The training process of MARS is similar to forward stepwise linear regression: at each step, MARS adds the new terms that give the largest reduction in the sum of squared errors, with coefficients estimated by OLS. However, the construction of new terms by MARS is more sophisticated and deserves further description.
Initially, the basis function \(h_0(X) = 1\) is added, giving a model with only an intercept term \(\beta_0\). At each subsequent step, an original predictor or a reflected pair of hinge functions is selected and added to the model. The selected pair (or original predictor) can enter the model directly; alternatively, it can be multiplied by a basis function that is already in the model (not by another candidate function), and the products become the new basis functions. The second case allows interactions between different predictors to be modeled. Note that a reflected pair of hinge functions always enters the model together (but its two members may be removed separately in the pruning process; see the next section). Training continues until a stopping condition is met, for example: 1. the model reaches the maximum number of terms allowed before pruning; 2. the improvement in \(R^2\) from adding new terms falls below the forward stepping threshold.
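A single forward step can be sketched in base R as follows. This is a simplified illustration rather than the earth implementation: it handles one continuous predictor, considers only pairs that enter the model directly (no products with existing basis functions), and ignores the stopping conditions. The function names `rss_of` and `forward_step` are hypothetical.

```r
# Residual sum of squares of an OLS fit of y on the columns of B
# (B must already contain the intercept column).
rss_of <- function(B, y) sum(lm.fit(B, y)$residuals^2)

# One forward step for a single predictor x: try every observed value
# as a knot and keep the reflected pair that gives the smallest RSS.
forward_step <- function(x, y, B) {
  knots <- sort(unique(x))
  rss <- vapply(knots, function(k) {
    rss_of(cbind(B, pmax(0, x - k), pmax(0, k - x)), y)
  }, numeric(1))
  list(knot = knots[which.min(rss)], rss = min(rss))
}

set.seed(1)
x <- runif(100, min = -2, max = 2)
y <- abs(x - 0.5) + rnorm(100, sd = 0.1)  # kink at x = 0.5

B0 <- cbind(intercept = rep(1, length(x)))  # model containing only h_0(X) = 1
step1 <- forward_step(x, y, B0)
print(step1$knot)
```

With this simulated data the selected knot should land near the true kink at 0.5, and the RSS should drop well below that of the intercept-only model.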
The Pruning Process
Although there are other methods, MARS typically applies a backward deletion procedure to prune the model. At each step, the algorithm removes the term whose deletion results in the smallest increase in the sum of squared errors, obtaining the best model \(\hat{f}_\lambda\) of each size \(\lambda\) (number of terms). The final model can be determined using cross-validation (CV), but generalized cross-validation (GCV) may be preferred since it is much more computationally efficient.
Generalization error given by GCV is defined as
\[ GCV(\lambda) = \frac{\sum^N_{i = 1}(y_i - \hat{f}_\lambda(x_i))^2}{(1 - \frac{M'(\lambda)}{N})^2} \]
where \(M'(\lambda)\) is the effective number of parameters in the model. \(M'(\lambda)\) is in turn defined as
\[ M'(\lambda) = M + cK \]
where \(M\) is the number of terms in the model, \(K\) is the number of knots in the model, and \(c = 2\) if the model does not involve interaction terms and \(c = 3\) otherwise [3].
[3] According to Hastie, Tibshirani, and Friedman (2009), these values are suggested by some mathematical and simulation results.
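The GCV criterion as written above can be computed with a small helper. This is a sketch of the formula in this post, not of the internal computation in any package; the function name `gcv` and its arguments are made up for the illustration.

```r
# GCV as defined above: RSS / (1 - M'(lambda) / N)^2, with M'(lambda) = M + cK.
gcv <- function(y, y_hat, n_terms, n_knots, has_interactions = FALSE) {
  c_pen <- if (has_interactions) 3 else 2  # c = 2 for additive models, 3 otherwise
  m_eff <- n_terms + c_pen * n_knots       # effective number of parameters
  n <- length(y)
  sum((y - y_hat)^2) / (1 - m_eff / n)^2
}

# Toy example: 20 observations, every residual is +/- 0.5, so RSS = 5.
y <- 1:20
y_hat <- y + rep(c(0.5, -0.5), times = 10)

gcv_additive    <- gcv(y, y_hat, n_terms = 3, n_knots = 1)  # 5 / (1 - 5/20)^2
gcv_interaction <- gcv(y, y_hat, n_terms = 3, n_knots = 1,
                       has_interactions = TRUE)             # 5 / (1 - 6/20)^2
print(c(gcv_additive, gcv_interaction))
```

For the same residuals, the larger penalty per knot gives a larger GCV, so models with interaction terms have to earn their extra knots with a bigger reduction in RSS.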
The Parameters
The earth package in R provides an implementation of MARS, on which caret::train() with method = "earth" is based. After conducting a simulation (see appendix), I found the following parameters to be particularly worth noting (please refer to the help page and associated documentation for more detailed information):
- Parameters for the training process:
  - degree: Maximum degree of interaction.
  - nk: Maximum number of model terms before pruning.
  - thresh: Forward stepping threshold.
  - minspan: Minimum number of observations between knots.
  - endspan: Minimum number of observations before the first and after the final knot.
  - linpreds: Index vector specifying which predictors should enter linearly.
  - allowed: Function specifying which predictors can interact and how.
  - newvar.penalty: Penalty for adding a new variable in the forward pass.
- Parameters for the pruning process:
  - nprune: Maximum number of terms (including intercept) in the pruned model.
  - nfold: Number of cross-validation folds (if CV is used to prune the model instead of GCV).
  - pmethod: Pruning method.
The caret::train() function considers degree and nprune to be the only major hyperparameters that need to be tuned. However, my experience with the simulation indicates that, if the underlying pattern in the training data is complicated, nk should be set to a larger value and thresh to a smaller value so that a more flexible model can be obtained from the training process; additionally, if the sample size is relatively small, it may also be necessary to use smaller values of minspan and endspan.
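As a rough sketch of how this looks in practice, the call below tunes nprune over a small grid with caret::train() and passes nk, thresh, minspan and endspan through to earth::earth() as extra arguments. The simulated data, the grid values and the 5-fold CV setup are arbitrary choices for this illustration, and both the caret and earth packages are assumed to be installed.

```r
library(caret)
library(earth)

# Simulated data (arbitrary choice for this sketch).
set.seed(1)
dat <- data.frame(x = runif(100, min = -3.5 * pi, max = 3.5 * pi))
dat$y <- sin(dat$x) + rnorm(100, sd = 0.3)

tuned <- train(
  y ~ x, data = dat,
  method = "earth",
  # degree and nprune are the hyperparameters caret tunes for this method
  tuneGrid = expand.grid(degree = 1, nprune = c(2, 4, 6, 8)),
  trControl = trainControl(method = "cv", number = 5),
  # extra arguments are passed on to earth::earth()
  nk = 20, thresh = 0, minspan = 1, endspan = 1
)

print(tuned$bestTune)
```

The grid only controls degree and nprune; the flexibility of the forward pass is fixed by the extra arguments, so they need to be chosen before tuning.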
One thing I noticed about the training function provided by the earth package (i.e., earth::earth()) is that it performs model selection by default and returns the best pruned model as the end result. This is reflected in the nprune parameter, which specifies the maximum rather than the exact number of terms in the final model. This feature can be turned off by specifying penalty = -1 in the training function. This special value of penalty causes earth::earth() to set the GCV to \(RSS/nrow(x)\). Since the RSS on the training set always decreases with more terms, the pruning pass will choose the maximum allowable number of terms.
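The effect of penalty = -1 can be checked with a short comparison. This is a sketch on simulated data (the data and the other parameter values are arbitrary choices for the illustration) and it assumes the earth package is installed.

```r
library(earth)

# Simulated data (arbitrary choice for this sketch).
set.seed(1)
x <- runif(100, min = -3.5 * pi, max = 3.5 * pi)
dat <- data.frame(x = x, y = sin(x) + rnorm(100, sd = 0.3))

# Default behaviour: the pruning pass selects the model size by GCV.
fit_default <- earth(y ~ x, data = dat,
                     nk = 20, thresh = 0, minspan = 1, endspan = 1)

# penalty = -1: GCV becomes RSS / nrow(x), so pruning keeps as many terms as allowed.
fit_no_selection <- earth(y ~ x, data = dat,
                          nk = 20, thresh = 0, minspan = 1, endspan = 1,
                          penalty = -1)

n_default      <- length(coef(fit_default))
n_no_selection <- length(coef(fit_no_selection))
print(c(default = n_default, no_selection = n_no_selection))
```

The second model should contain at least as many terms as the first, because nothing penalizes the additional terms during pruning.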
Appendix: a Simulation
A simulation can be conducted to show how different values of nk, thresh, minspan and endspan affect the model-training process. You can run the code below in R to see the results for yourself.
The results indicate that the final model is more flexible if nk is larger, thresh is smaller, and minspan and endspan are smaller. In this simulation, the model that best captures the underlying pattern is the one with nk = 20, thresh = 0 and minspan = endspan = 1 (i.e., the most flexible one, which also contains the largest number of terms).
A word of caution: this simulation does not imply a more flexible MARS model is always more accurate in terms of test error. It works well in this case perhaps because the random noise is relatively small and there are no correlated or redundant predictors. Additionally, a critical strength of MARS is that it can train models that are very interpretable [4]; however, increasing the flexibility generally reduces the interpretability. If interpretability is not a key consideration, then perhaps a more flexible algorithm such as random forest should be used instead.
[4] This is the reason why the degree of interaction is usually limited to one or two but rarely above
library(tidyverse)
library(earth)
# generate a training set
set.seed(1)
sample_size <- 100
dat <- tibble(
  x = runif(sample_size, min = -3.5 * pi, max = 3.5 * pi),
  e = rnorm(sample_size, sd = 0.3),
  y = sin(x) + e
)
# the underlying pattern of the data set
ggplot(dat %>% mutate(y = sin(x))) +
geom_line(aes(x = x, y = y))
# train MARS models with different values of nk, thresh and span (including minspan and endspan)
results <- crossing(nk = c(5, 10, 20),
                    thresh = c(0, 0.01, 0.1),
                    span = c(1, 5, 10)) %>%
  pmap(function(nk, thresh, span, dat) {
    fit <- earth(y ~ x, data = dat,
                 degree = 1, nprune = NULL,
                 nk = nk, thresh = thresh, minspan = span, endspan = span)
    mutate(dat, nk = nk, thresh = thresh, span = span,
           y_predicted = predict(fit)[,1], p = length(coef(fit)))
  }, dat = dat) %>%
  bind_rows()
# predictions given by different models
ggplot(results) +
geom_point(aes(x = x, y = y_predicted)) +
facet_grid(nk ~ thresh + span, labeller = label_both)
# the number of terms included in each model with different parameters (the output is attached below)
results %>%
  select(nk, thresh, span, p) %>%
  distinct() %>%
  mutate(nk = as.factor(nk) %>% fct_relabel(function(x) {paste0("nk=", x)})) %>%
  split(.$nk) %>%
  map(function(x) {
    x %>%
      select(-nk) %>%
      spread(key = span, value = p, sep = "=") %>%
      as.data.frame() %>%
      `rownames<-`(paste0("thresh=", .$thresh)) %>%
      select(-thresh)
  })
#> $`nk=5`
#>             span=1 span=5 span=10
#> thresh=0         4      4       2
#> thresh=0.01      4      4       2
#> thresh=0.1       1      1       1
#>
#> $`nk=10`
#>             span=1 span=5 span=10
#> thresh=0         5      5       4
#> thresh=0.01      5      5       4
#> thresh=0.1       1      1       1
#>
#> $`nk=20`
#>             span=1 span=5 span=10
#> thresh=0         8      7       4
#> thresh=0.01      5      7       4
#> thresh=0.1       1      1       1