13.1 Stepwise Functions#

The Dataset#

We will use the Mid-Atlantic Wage dataset from the ISLP library to demonstrate fitting stepwise functions. Our goal is to predict wage from age by dividing age into bins and using the average wage within each bin as the prediction.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm
import patsy
from ISLP import load_data

# Load and plot the data
df = load_data('Wage')

fig, ax = plt.subplots(figsize=(8,5))
sns.scatterplot(data=df, x='age', y='wage', alpha=0.4, ax=ax)       
ax.set_title("ISLP data: Wage distribution across age");
[Figure: scatter plot of wage against age, titled "ISLP data: Wage distribution across age"]

Fitting a stepwise function#

To fit a stepwise function, we first need cut points for our predictor variable age. For example, we can split the range of age into 4 intervals of equal width:

bins = pd.cut(df['age'], 4)
print(bins)
0       (17.938, 33.5]
1       (17.938, 33.5]
2         (33.5, 49.0]
3         (33.5, 49.0]
4         (49.0, 64.5]
             ...      
2995      (33.5, 49.0]
2996    (17.938, 33.5]
2997    (17.938, 33.5]
2998    (17.938, 33.5]
2999      (49.0, 64.5]
Name: age, Length: 3000, dtype: category
Categories (4, interval[float64, right]): [(17.938, 33.5] < (33.5, 49.0] < (49.0, 64.5] < (64.5, 80.0]]

The output gives us four equal-width intervals of age. Note that the intervals have equal width, not equal numbers of observations. We can use the upper bounds of the first three intervals (33.5, 49 and 64.5) as cut points for our stepwise function. We then use the dmatrix function to create a design matrix:

transformed_age = patsy.dmatrix("bs(age, knots=(33.5, 49, 64.5), degree=0)",
                                data={"age": df['age']},
                                return_type='dataframe')
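
The cut points above were read off the printed intervals by hand. As a minimal sketch, they can also be obtained programmatically through the `retbins=True` argument of `pd.cut`, which returns the bin edges alongside the binned values. The `age` series below is a hypothetical stand-in for `df['age']`, assuming ages that span 18 to 80 as the printed intervals suggest:

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for df['age']; for equal-width bins only the
# minimum (18) and maximum (80) determine the edges.
age = pd.Series([18, 25, 33, 41, 49, 57, 64, 72, 80])

# retbins=True additionally returns the array of bin edges
_, edges = pd.cut(age, 4, retbins=True)

# Drop the two outer edges; the interior edges are the cut points
knots = edges[1:-1]
print(knots)
```

The resulting `knots` could then be passed to the formula instead of hard-coded numbers.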

When you specify "bs(age, ...)", you tell dmatrix to transform age using B-spline basis functions. These basis functions segment the variable age into piecewise polynomials defined by the specified knots and degree. When you set degree to zero, it creates piecewise constant functions, meaning the resulting function is a step function consisting of only horizontal segments.
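
To make the idea of piecewise constant basis functions concrete, the sketch below builds indicator columns by hand with NumPy. It is an illustration of the concept, not patsy's internal implementation: the ages and the `heights` vector are made up, and the left-closed bin convention (an age equal to a cut point belongs to the bin on its right) is an assumption.

```python
import numpy as np

knots = np.array([33.5, 49.0, 64.5])          # cut points from the text
age = np.array([18, 30, 40, 49, 60, 70, 80])  # hypothetical ages

# Bin index 0..3 for each age: the number of knots at or below the age
bin_idx = np.searchsorted(knots, age, side="right")

# One indicator column per bin; exactly one entry in each row equals 1
indicators = (bin_idx[:, None] == np.arange(len(knots) + 1)).astype(float)

# A step function is a weighted sum of these columns: one height per bin
heights = np.array([1.0, 2.0, 3.0, 4.0])      # arbitrary illustrative heights
step_values = indicators @ heights
print(step_values)
```

Every age in the same bin receives the same value, which is why the fitted curve consists of horizontal segments only.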

Fitting the model#

Once age has been transformed into a design matrix, we can fit the model like an ordinary regression with a categorical predictor:

model = sm.OLS(df['wage'], transformed_age)
model_fit = model.fit()

print(model_fit.summary())
                            OLS Regression Results                            
==============================================================================
Dep. Variable:                   wage   R-squared:                       0.062
Model:                            OLS   Adj. R-squared:                  0.062
Method:                 Least Squares   F-statistic:                     66.56
Date:                Wed, 09 Apr 2025   Prob (F-statistic):           1.16e-41
Time:                        08:53:29   Log-Likelihood:                -15353.
No. Observations:                3000   AIC:                         3.071e+04
Df Residuals:                    2996   BIC:                         3.074e+04
Df Model:                           3                                         
Covariance Type:            nonrobust                                         
================================================================================================================
                                                   coef    std err          t      P>|t|      [0.025      0.975]
----------------------------------------------------------------------------------------------------------------
Intercept                                       94.1584      1.476     63.789      0.000      91.264      97.053
bs(age, knots=(33.5, 49, 64.5), degree=0)[0]    23.9331      1.849     12.941      0.000      20.307      27.560
bs(age, knots=(33.5, 49, 64.5), degree=0)[1]    23.8857      2.019     11.833      0.000      19.928      27.844
bs(age, knots=(33.5, 49, 64.5), degree=0)[2]     7.6406      4.987      1.532      0.126      -2.139      17.420
==============================================================================
Omnibus:                     1062.290   Durbin-Watson:                   1.965
Prob(Omnibus):                  0.000   Jarque-Bera (JB):             4549.991
Skew:                           1.681   Prob(JB):                         0.00
Kurtosis:                       8.010   Cond. No.                         7.84
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
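
It may help to write out the model behind this table. As a sketch, assuming that the first bin serves as the reference category (its basis column is absorbed into the intercept) and using the common left-closed convention for step functions, the fitted model can be written as

\[
\text{wage}_i = \beta_0 + \beta_1\, I(33.5 \le \text{age}_i < 49) + \beta_2\, I(49 \le \text{age}_i < 64.5) + \beta_3\, I(\text{age}_i \ge 64.5) + \epsilon_i
\]

where \(I(\cdot)\) equals 1 if the condition holds and 0 otherwise. Here \(\beta_0\) corresponds to the Intercept row and \(\beta_1, \beta_2, \beta_3\) to the rows labelled [0], [1] and [2]. For an observation in the first bin all indicators are zero, so the prediction is \(\beta_0\); in any other bin it is \(\beta_0\) plus that bin's coefficient. How patsy assigns an age that falls exactly on a knot has not been verified here.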

The summary tells us:

  • The average wage in bin 1 (ages 17.9 to 33.5 years) is $94.16k per year.

  • The average wage in bin 2 (ages 33.5 to 49 years) is $23.93k per year higher than in bin 1.

  • The average wage in bin 3 (ages 49 to 64.5 years) is $23.89k per year higher than in bin 1.

  • The average wage in bin 4 (ages 64.5 to 80 years) is $7.64k per year higher than in bin 1.

The second and third bins differ significantly from the first bin (p < .001 for both). The fourth bin does not differ significantly from the first one (p = .126).

In summary, the model suggests that wages vary with age, but not uniformly across age groups. Average wages are clearly higher in the two middle age groups than in the youngest group, whereas the oldest group's average wage is not significantly different from that of the youngest group in this sample.
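
The reading of the coefficients as a bin mean and differences in bin means can be checked numerically. The sketch below uses synthetic data (the `age` and `wage` values are made up, not the Wage dataset) and plain least squares in NumPy rather than statsmodels. It assumes bin 1 is the reference category and that an age equal to a cut point belongs to the bin on its right.

```python
import numpy as np

rng = np.random.default_rng(0)
knots = np.array([33.5, 49.0, 64.5])

# Synthetic stand-in for the Wage data (values are made up for illustration)
age = rng.integers(18, 81, size=500)
wage = 100 + rng.normal(0, 20, size=500)

# Bin index 0..3 for each observation
bin_idx = np.searchsorted(knots, age, side="right")

# Design matrix: intercept plus indicators for bins 2-4 (bin 1 is the reference)
X = np.column_stack(
    [np.ones(len(age))] + [(bin_idx == k).astype(float) for k in (1, 2, 3)]
)
coef, *_ = np.linalg.lstsq(X, wage, rcond=None)

# Compare with the plain averages per bin
bin_means = np.array([wage[bin_idx == k].mean() for k in range(4)])
print(coef[0], bin_means[0])                    # intercept vs. mean of bin 1
print(coef[1:], bin_means[1:] - bin_means[0])   # coefficients vs. differences in means
```

The intercept matches the mean of the first bin, and each remaining coefficient matches the difference between that bin's mean and the first bin's mean.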

Plotting the model#

We can also plot the model. Note that bins 2 and 3 have very similar estimates, so the fitted \(Y\) values in the second and third bins are almost indistinguishable in the plot.

# Plot the model
fig, ax = plt.subplots(figsize=(8,5))

# Generate age values for predictions and transform using the spline basis
xp = np.linspace(df['age'].min(), df['age'].max(), 100)
xp_trans = patsy.dmatrix("bs(xp, knots=(33.5, 49, 64.5), degree=0)",
                         data={"xp": xp},
                         return_type='dataframe')

# Use model fitted before to predict wages for generated age values
predictions = model_fit.predict(xp_trans)

# Plot the original data and the model fit
sns.scatterplot(data=df, x="age", y="wage", alpha=0.4, ax=ax)
ax.axvline(33.5, linestyle='--', alpha=0.4, color="black")      # Cut point 1
ax.axvline(49, linestyle='--', alpha=0.4, color="black")        # Cut point 2
ax.axvline(64.5, linestyle='--', alpha=0.4, color="black")      # Cut point 3
ax.plot(xp, predictions, color='red')
ax.set_title("ISLP data: stepwise fit (zero-order)");
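[Figure: scatter plot of wage against age with dashed vertical lines at the cut points 33.5, 49 and 64.5 and the fitted step function in red, titled "ISLP data: stepwise fit (zero-order)"]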