
Generalized Additive Models#

Generalized Additive Models (GAMs) offer a powerful and flexible extension to traditional linear models by allowing for non-linear, additive relationships between each predictor and the outcome. Unlike standard linear regression, which assumes a strictly linear association between predictors and the response variable, GAMs replace each linear term with a smooth function, enabling the model to better capture complex patterns in the data. Thanks to their additive structure, each predictor contributes independently to the model, making it easy to interpret the effect of each variable.

So instead of using the standard linear function

\[ y = b_0 + b_1 x_1 + b_2 x_2 + ... + b_p x_p + \epsilon \]

we have something like this:

\[ y = b_0 + f_1(x_1) + f_2(x_2) + ... + f_p(x_p) + \epsilon \]

In short, instead of using fixed slope coefficients \(b_j\) that assume a straight-line relationship between each predictor and the outcome, we replace them with flexible (possibly non-linear) smooth functions \(f_j\). These functions can be anything from constant functions and polynomials to splines or wavelets.

Compared to simple spline regression, which was introduced last semester to predict \(y\) from a single predictor \(x\), GAMs generalize this idea and allow us to predict \(y\) from multiple predictors \(x_1, ..., x_p\).
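Concretely, each smooth function \(f_j\) is built as a weighted sum of basis functions. With a B-spline basis, as we will use below, this reads

\[ f_j(x_j) = \sum_{k=1}^{K_j} \beta_{jk} B_{jk}(x_j) \]

so fitting the GAM still amounts to estimating linear coefficients \(\beta_{jk}\), one per basis function. These are exactly the age_s0, age_s1, … coefficients that will appear in the model summary below.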

Today’s data: The Diabetes dataset#

For today’s practical demonstration, we will work with the Diabetes dataset from scikit-learn. This dataset contains medical information collected from 442 diabetes patients, including:

  • Features: 10 baseline measures taken at the beginning of the study: age, sex, Body Mass Index (BMI), average blood pressure, as well as six blood serum measurements (e.g. cholesterol and blood sugar)

  • Target: A quantitative measure of disease progression one year after the baseline measurements were taken.

You can find more information in the scikit-learn documentation.

from sklearn import datasets
from sklearn.model_selection import train_test_split

# Get data
diabetes = datasets.load_diabetes(as_frame=True)
X = diabetes.data
y = diabetes.target

# Split the data into training and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

To explore the relationships between each feature and the target, we plot each predictor against the disease progression outcome. These scatter plots with simple linear regression fits help us visually assess whether the relationship between a feature and the target is linear or whether a more flexible modelling approach is needed:
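The plotting code is hidden in this notebook; a minimal sketch that produces comparable plots (assuming seaborn is used for the regression overlays) could look like this:

import matplotlib.pyplot as plt
import seaborn as sns

# All continuous features (sex is treated as categorical later on)
features = X_train.columns.drop('sex')

fig, axes = plt.subplots(3, 3, figsize=(9, 9))
for ax, feature in zip(axes.ravel(), features):
    # Scatter plot with a simple linear regression fit per feature
    sns.regplot(x=X_train[feature], y=y_train, ax=ax,
                scatter_kws={'s': 10}, line_kws={'color': 'red'})
plt.tight_layout()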

(Figure: each feature plotted against disease progression, with a simple linear regression fit per panel.)

Although the linear regression fits seem reasonable, we might suspect that a more flexible approach could be beneficial, so let’s try it!

GAMs in Python#

There are multiple options for implementing GAMs in Python. Here we will use statsmodels, as you should already be familiar with it from the previous semester. The workflow is as follows:

  1. Separate smooth (continuous) and categorical features

  2. Create spline basis functions for the continuous features using B-splines

  3. Fit the GAM with both the smooth and categorical predictors

from statsmodels.gam.api import GLMGam, BSplines

spline_features = ['age', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6']
categorical_features = ['sex']

# Create smoother for continuous variables
bs = BSplines(X_train[spline_features], df=[6]*len(spline_features), degree=[3]*len(spline_features))

# Set up the GAM: smooth terms via the smoother, categorical features via exog
gam = GLMGam(y_train, exog=X_train[categorical_features], smoother=bs)
res = gam.fit()

print(res.summary())
                 Generalized Linear Model Regression Results                  
==============================================================================
Dep. Variable:                 target   No. Observations:                  309
Model:                         GLMGam   Df Residuals:                   263.00
Model Family:                Gaussian   Df Model:                        45.00
Link Function:               Identity   Scale:                          2998.6
Method:                         PIRLS   Log-Likelihood:                -1650.5
Date:                Mon, 27 Oct 2025   Deviance:                   7.8864e+05
Time:                        13:07:03   Pearson chi2:                 7.89e+05
No. Iterations:                     3   Pseudo R-squ. (CS):             0.7023
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------------------------
sex         -289.6664     78.808     -3.676      0.000    -444.126    -135.206
age_s0       -32.4246     50.442     -0.643      0.520    -131.289      66.439
age_s1       -61.4878     28.887     -2.129      0.033    -118.105      -4.871
age_s2        -5.3132     39.461     -0.135      0.893     -82.655      72.029
age_s3       -42.6463     38.246     -1.115      0.265    -117.608      32.315
age_s4        18.0807     60.041      0.301      0.763     -99.598     135.759
bmi_s0        43.8396     54.386      0.806      0.420     -62.756     150.435
bmi_s1        19.4553     33.130      0.587      0.557     -45.479      84.390
bmi_s2        97.0613     48.512      2.001      0.045       1.980     192.142
bmi_s3        57.8888     49.412      1.172      0.241     -38.956     154.734
bmi_s4       252.4823     59.944      4.212      0.000     134.995     369.969
bp_s0         30.6097     58.249      0.525      0.599     -83.556     144.775
bp_s1         38.3179     38.098      1.006      0.315     -36.353     112.989
bp_s2         76.9274     49.565      1.552      0.121     -20.218     174.073
bp_s3         81.3546     48.055      1.693      0.090     -12.832     175.542
bp_s4        147.7945     57.513      2.570      0.010      35.070     260.519
s1_s0         58.6364    106.319      0.552      0.581    -149.744     267.017
s1_s1       -125.9137    154.677     -0.814      0.416    -429.075     177.247
s1_s2       -233.7171    336.444     -0.695      0.487    -893.135     425.701
s1_s3       -439.3612    450.331     -0.976      0.329   -1321.994     443.272
s1_s4       -351.3655    556.288     -0.632      0.528   -1441.670     738.939
s2_s0         26.6938     99.713      0.268      0.789    -168.741     222.128
s2_s1        125.9591    143.712      0.876      0.381    -155.712     407.630
s2_s2        274.4660    332.759      0.825      0.409    -377.729     926.661
s2_s3        464.5552    472.107      0.984      0.325    -460.758    1389.868
s2_s4        262.7111    613.544      0.428      0.669    -939.813    1465.235
s3_s0         14.3698     64.625      0.222      0.824    -112.292     141.032
s3_s1         40.7103     49.233      0.827      0.408     -55.784     137.205
s3_s2         39.5808     95.906      0.413      0.680    -148.392     227.554
s3_s3        101.7901    148.474      0.686      0.493    -189.214     392.794
s3_s4        120.5364    187.063      0.644      0.519    -246.101     487.174
s4_s0        -37.5731     46.372     -0.810      0.418    -128.461      53.315
s4_s1         11.2719     36.681      0.307      0.759     -60.622      83.166
s4_s2        -64.0818     65.922     -0.972      0.331    -193.286      65.123
s4_s3         48.0156     87.235      0.550      0.582    -122.961     218.992
s4_s4         25.2745    109.679      0.230      0.818    -189.693     240.242
s5_s0        -22.0863     49.386     -0.447      0.655    -118.882      74.709
s5_s1         26.4723     30.680      0.863      0.388     -33.660      86.605
s5_s2        127.4137     69.308      1.838      0.066      -8.428     263.255
s5_s3        174.7068    139.025      1.257      0.209     -97.776     447.190
s5_s4        264.2489    238.234      1.109      0.267    -202.681     731.179
s6_s0        -72.3705     79.891     -0.906      0.365    -228.954      84.213
s6_s1        -33.7499     51.643     -0.654      0.513    -134.969      67.469
s6_s2        -74.4075     61.859     -1.203      0.229    -195.648      46.833
s6_s3         -6.0775     59.770     -0.102      0.919    -123.224     111.069
s6_s4        -28.6995     62.161     -0.462      0.644    -150.532      93.133
==============================================================================

For the B-splines we choose:

  • df=[6]*len(spline_features) -> 6 degrees of freedom per feature; since the constant column is removed for identifiability, each feature contributes five spline coefficients to the summary (e.g. age_s0 to age_s4)

  • degree=[3]*len(spline_features) -> cubic splines (degree 3)
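If the resulting fits look too wiggly, GLMGam also accepts a roughness penalty per smooth term via its alpha argument (larger values pull the smooths towards simpler, closer-to-linear shapes). A minimal sketch; the penalty value 10.0 is an arbitrary illustration, not a tuned choice:

# Refit with a roughness penalty on each smooth term
gam_pen = GLMGam(y_train, exog=X_train[categorical_features],
                 smoother=bs, alpha=[10.0] * len(spline_features))
res_pen = gam_pen.fit()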

The output includes parameter estimates for all spline basis functions and categorical variables:

  • The coef column shows the estimated effect.

  • The P>|z| column tells you whether the estimate is statistically significant.

  • The Pseudo R-squared (CS) gives a rough measure of model fit (here around 0.70).
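To see how well the fitted GAM generalizes, we can also evaluate it on the held-out test set. A quick sketch: res.predict takes the categorical part via exog and the continuous features via exog_smooth (note that test values outside a feature’s training range can be problematic for the spline basis):

from sklearn.metrics import mean_squared_error, r2_score

# Predict disease progression for the test set
y_pred = res.predict(exog=X_test[categorical_features],
                     exog_smooth=X_test[spline_features])

print(f"Test MSE: {mean_squared_error(y_test, y_pred):.1f}")
print(f"Test R^2: {r2_score(y_test, y_pred):.3f}")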

Due to the additive nature of GAMs, we can isolate and visualise the effect of each smooth term individually. This helps us understand the relationship between each predictor and the response, controlling for all other variables:

import matplotlib.pyplot as plt
fig, ax = plt.subplots(3,3, figsize=(7,7))

for i, feature in enumerate(spline_features):
    res.plot_partial(i, cpr=True, ax=ax[i//3, i%3])
    ax[i//3, i%3].set_title(f"Partial Effect: {feature}")

plt.tight_layout()
(Figure: partial effect plots for each smooth term, one panel per continuous feature.)

Each subplot shows the modelled, possibly non-linear relationship between a feature and the target. The cpr=True option overlays the partial (component-plus-residual) observations on the estimated smooth curve; the confidence band around the curve is drawn by default (plot_se=True).

Summary#

  • GAMs allow flexible, interpretable models where you don’t assume linearity for every predictor.

  • statsmodels makes it easy to combine smooth terms (B-splines or, alternatively, cyclic cubic splines; see the sketch below) with categorical or linear predictors.

  • You can inspect smooth effects with .plot_partial(), and linear terms directly from the model summary.
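As a closing illustration of the alternative smoother mentioned above, a minimal sketch using cyclic cubic splines. These constrain the smooth to take the same value at both ends of the feature range, which is mainly useful for periodic predictors (e.g. hour of day or month), so this is shown for syntax only; the import path is assumed to be statsmodels.gam.api (the class also lives in statsmodels.gam.smooth_basis):

from statsmodels.gam.api import CyclicCubicSplines

# Cyclic cubic spline basis for the same continuous features;
# the basis is always cubic, so only df needs to be specified
ccs = CyclicCubicSplines(X_train[spline_features], df=[6] * len(spline_features))

gam_ccs = GLMGam(y_train, exog=X_train[categorical_features], smoother=ccs)
res_ccs = gam_ccs.fit()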