Overfitting¶
Announcements¶
I’ll post a sample midterm soon (Saturday at the latest). The midterm is Thursday of Week 8 (May 19).
Be sure to attend tomorrow’s Discussion Section:
Yasmeen will introduce the dataset for the homework and help with the first few homework questions.
Note cards for the midterm will be handed out.
import pandas as pd
import altair as alt
Simulated data¶
We will work with the same simulated dataset as on Monday. The true underlying function has the form \(f(x) = c_2 x^2 + c_1 x + c_0\). The true outputs are stored in the “y_true” column. We’ve hidden this true data by adding random noise to each output and putting the result in the “y” column. The columns “x1” through “x10” hold the powers \(x^1, \dots, x^{10}\) of the input, which we will use as features for polynomial regression.
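The exact code that produced sim_data.csv isn’t shown in these notes, but here is a minimal sketch of how a dataset with this structure could be generated. The coefficients, noise level, and range of x values below are made up for illustration; only the overall shape (a quadratic plus noise, with power columns x1 through x10) matches our file.

import numpy as np

# Hypothetical data-generating sketch; the true coefficients and noise level
# used for sim_data.csv are not shown in these notes.
rng = np.random.default_rng(0)
x = rng.uniform(-10, 7, size=50)
y_true = 1.5*x**2 + 2*x - 3                    # f(x) = c2*x**2 + c1*x + c0
y = y_true + rng.normal(scale=60, size=50)     # noisy outputs, like the "y" column

df_sim = pd.DataFrame({"x": x, "y_true": y_true, "y": y})
for i in range(1, 11):
    df_sim[f"x{i}"] = x**i                     # power columns x1, ..., x10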
df = pd.read_csv("../data/sim_data.csv")
df.head()
| | x | y_true | y | x1 | x2 | x3 | x4 | x5 | x6 | x7 | x8 | x9 | x10 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | -3.329208 | -18.207589 | -117.484900 | -3.329208 | 11.083626 | -36.899694 | 122.846756 | -408.982395 | 1361.587441 | -4533.007730 | 1.509133e+04 | -5.024216e+04 | 1.672666e+05 |
1 | 6.465018 | 74.160562 | 73.954907 | 6.465018 | 41.796463 | 270.214901 | 1746.944309 | 11294.027098 | 73016.092970 | 472050.384357 | 3.051814e+06 | 1.973004e+07 | 1.275550e+08 |
2 | -4.478046 | -7.670062 | -13.810089 | -4.478046 | 20.052899 | -89.797810 | 402.118751 | -1800.706392 | 8063.646628 | -36109.383086 | 1.616995e+05 | -7.240978e+05 | 3.242544e+06 |
3 | 2.043272 | -7.925152 | 19.461182 | 2.043272 | 4.174960 | 8.530580 | 17.430295 | 35.614834 | 72.770792 | 148.690523 | 3.038152e+02 | 6.207771e+02 | 1.268416e+03 |
4 | 4.850593 | 36.485466 | 22.375230 | 4.850593 | 23.528255 | 114.125996 | 553.578791 | 2685.185564 | 13024.743051 | 63177.731115 | 3.064495e+05 | 1.486462e+06 | 7.210222e+06 |
max_deg = 10
cols = [f"x{i}" for i in range(1, max_deg+1)]
# chart of the true (noise-free) outputs
c_true = alt.Chart(df).mark_circle(color="black").encode(
    x="x",
    y="y_true",
    tooltip=["x", "y_true", "y"]
)

# chart of the noisy outputs
c = alt.Chart(df).mark_circle().encode(
    x="x",
    y="y",
    tooltip=["x", "y_true", "y"]
)
c_true
The black points represent the true underlying model, while the blue points represent the randomized version. Real-world data is almost always closer to the randomized version.
c+c_true
Dividing into a training set and a test set¶
The motto: fit on the training set, then predict and evaluate on the test set.
from sklearn.model_selection import train_test_split
cols
['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10']
We use random_state to guarantee that we all get the same train/test split. It also guarantees that this particular split exhibits the behavior I want to illustrate.
X_train, X_test, y_train, y_test = train_test_split(df[cols], df["y"], train_size=0.24, random_state=6)
df[cols].shape
(50, 10)
The X_train variable will hold 24% of the rows from df[cols].
X_train.shape
(12, 10)
0.24*50
12.0
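As a quick sanity check (not part of the original lecture), re-running the split with the same arguments and the same random_state reproduces exactly the same rows:

# Splitting again with the same random_state should give back the identical rows.
X_train2, X_test2, y_train2, y_test2 = train_test_split(
    df[cols], df["y"], train_size=0.24, random_state=6
)
(X_train2.index == X_train.index).all()   # expected: True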
X_train
| | x1 | x2 | x3 | x4 | x5 | x6 | x7 | x8 | x9 | x10 |
|---|---|---|---|---|---|---|---|---|---|---|
13 | -8.982275 | 80.681271 | -724.701394 | 6509.467488 | -58469.829551 | 525192.110405 | -4.717420e+06 | 4.237317e+07 | -3.806075e+08 | 3.418721e+09 |
11 | 6.449735 | 41.599085 | 268.303082 | 1730.483838 | 11161.162581 | 71986.543546 | 4.642941e+05 | 2.994574e+06 | 1.931421e+07 | 1.245716e+08 |
1 | 6.465018 | 41.796463 | 270.214901 | 1746.944309 | 11294.027098 | 73016.092970 | 4.720504e+05 | 3.051814e+06 | 1.973004e+07 | 1.275550e+08 |
25 | 0.371593 | 0.138081 | 0.051310 | 0.019066 | 0.007085 | 0.002633 | 9.782985e-04 | 3.635287e-04 | 1.350847e-04 | 5.019650e-05 |
16 | -3.942982 | 15.547104 | -61.301947 | 241.712453 | -953.067768 | 3757.928725 | -1.481744e+04 | 5.842491e+04 | -2.303683e+05 | 9.083382e+05 |
45 | -4.457650 | 19.870645 | -88.576385 | 394.842539 | -1760.069919 | 7845.776000 | -3.497372e+04 | 1.559006e+05 | -6.949505e+05 | 3.097846e+06 |
15 | 5.481667 | 30.048674 | 164.716823 | 902.922780 | 4949.522041 | 27131.631828 | 1.487266e+05 | 8.152695e+05 | 4.469036e+06 | 2.449777e+07 |
42 | -9.804880 | 96.135668 | -942.598671 | 9242.066675 | -90617.352934 | 888492.254125 | -8.711560e+06 | 8.541580e+07 | -8.374916e+08 | 8.211505e+09 |
20 | -0.690041 | 0.476157 | -0.328568 | 0.226726 | -0.156450 | 0.107957 | -7.449486e-02 | 5.140454e-02 | -3.547126e-02 | 2.447664e-02 |
35 | 4.852230 | 23.544134 | 114.241552 | 554.326269 | 2689.718466 | 13051.132219 | 6.332709e+04 | 3.072776e+05 | 1.490982e+06 | 7.234585e+06 |
9 | 4.556122 | 20.758251 | 94.577131 | 430.904978 | 1963.255801 | 8944.833631 | 4.075376e+04 | 1.856791e+05 | 8.459767e+05 | 3.854373e+06 |
10 | -6.299257 | 39.680640 | -249.958554 | 1574.553197 | -9918.515413 | 62479.278687 | -3.935730e+05 | 2.479218e+06 | -1.561723e+07 | 9.837695e+07 |
The X_test variable will hold the remaining rows.
X_test
| | x1 | x2 | x3 | x4 | x5 | x6 | x7 | x8 | x9 | x10 |
|---|---|---|---|---|---|---|---|---|---|---|
49 | 6.068927 | 36.831870 | 223.529915 | 1356.586642 | 8233.024737 | 49965.622715 | 3.032377e+05 | 1.840327e+06 | 1.116881e+07 | 6.778270e+07 |
40 | -0.364001 | 0.132497 | -0.048229 | 0.017555 | -0.006390 | 0.002326 | -8.466826e-04 | 3.081936e-04 | -1.121829e-04 | 4.083471e-05 |
38 | -1.602675 | 2.568568 | -4.116580 | 6.597541 | -10.573716 | 16.946233 | -2.715931e+01 | 4.352755e+01 | -6.976053e+01 | 1.118035e+02 |
23 | 4.441058 | 19.722998 | 87.590980 | 388.996638 | 1727.556699 | 7672.179799 | 3.407260e+04 | 1.513184e+05 | 6.720137e+05 | 2.984452e+06 |
7 | -5.656363 | 31.994443 | -180.972185 | 1023.644381 | -5790.104257 | 32750.931784 | -1.852512e+05 | 1.047848e+06 | -5.927008e+06 | 3.352531e+07 |
0 | -3.329208 | 11.083626 | -36.899694 | 122.846756 | -408.982395 | 1361.587441 | -4.533008e+03 | 1.509133e+04 | -5.024216e+04 | 1.672666e+05 |
6 | -6.557409 | 42.999609 | -281.966007 | 1848.966339 | -12124.427923 | 79504.828902 | -5.213457e+05 | 3.418677e+06 | -2.241766e+07 | 1.470018e+08 |
34 | -1.384337 | 1.916390 | -2.652930 | 3.672550 | -5.084049 | 7.038039 | -9.743020e+00 | 1.348763e+01 | -1.867143e+01 | 2.584755e+01 |
14 | -9.026659 | 81.480574 | -735.497357 | 6639.083877 | -59928.746571 | 540956.362638 | -4.883029e+06 | 4.407743e+07 | -3.978720e+08 | 3.591455e+09 |
31 | 6.893150 | 47.515521 | 327.531630 | 2257.724760 | 15562.836137 | 107276.968808 | 7.394763e+05 | 5.097321e+06 | 3.513660e+07 | 2.422019e+08 |
48 | 3.778329 | 14.275771 | 53.938563 | 203.797646 | 770.014590 | 2909.368585 | 1.099255e+04 | 4.153348e+04 | 1.569272e+05 | 5.929225e+05 |
24 | -6.563454 | 43.078928 | -282.746563 | 1855.794057 | -12180.418898 | 79945.618965 | -5.247194e+05 | 3.443972e+06 | -2.260435e+07 | 1.483626e+08 |
19 | -9.924588 | 98.497442 | -977.546500 | 9701.745999 | -96285.829291 | 955597.159837 | -9.483908e+06 | 9.412388e+07 | -9.341407e+08 | 9.270961e+09 |
3 | 2.043272 | 4.174960 | 8.530580 | 17.430295 | 35.614834 | 72.770792 | 1.486905e+02 | 3.038152e+02 | 6.207771e+02 | 1.268416e+03 |
41 | 6.656568 | 44.309900 | 294.951869 | 1963.367222 | 13069.287756 | 86996.604895 | 5.790988e+05 | 3.854811e+06 | 2.565981e+07 | 1.708063e+08 |
28 | 3.165616 | 10.021126 | 31.723040 | 100.422972 | 317.900593 | 1006.351284 | 3.185722e+03 | 1.008477e+04 | 3.192452e+04 | 1.010608e+05 |
43 | 1.236174 | 1.528127 | 1.889031 | 2.335171 | 2.886679 | 3.568438 | 4.411211e+00 | 5.453025e+00 | 6.740889e+00 | 8.332913e+00 |
30 | -6.470932 | 41.872959 | -270.957061 | 1753.344667 | -11345.773804 | 73417.728768 | -4.750811e+05 | 3.074218e+06 | -1.989305e+07 | 1.287266e+08 |
47 | -3.012487 | 9.075077 | -27.338548 | 82.357015 | -248.099419 | 747.396219 | -2.251521e+03 | 6.782678e+03 | -2.043273e+04 | 6.155332e+04 |
17 | 7.135269 | 50.912060 | 363.271234 | 2592.037885 | 18494.886953 | 131965.989144 | 9.416128e+05 | 6.718660e+06 | 4.793945e+07 | 3.420608e+08 |
21 | 4.510288 | 20.342695 | 91.751404 | 413.825225 | 1866.470799 | 8418.320181 | 3.796905e+04 | 1.712513e+05 | 7.723927e+05 | 3.483713e+06 |
29 | 5.222245 | 27.271843 | 142.420249 | 743.753442 | 3884.062733 | 20283.527393 | 1.059256e+05 | 5.531692e+05 | 2.888785e+06 | 1.508594e+07 |
39 | -8.223553 | 67.626827 | -556.132805 | 4573.387686 | -37609.496782 | 309283.696327 | -2.543411e+06 | 2.091587e+07 | -1.720028e+08 | 1.414474e+09 |
18 | 4.196026 | 17.606638 | 73.877921 | 309.993714 | 1300.741838 | 5457.947214 | 2.290169e+04 | 9.609610e+04 | 4.032218e+05 | 1.691929e+06 |
22 | -9.336060 | 87.162025 | -813.749934 | 7597.218586 | -70928.092042 | 662188.955593 | -6.182236e+06 | 5.771773e+07 | -5.388562e+08 | 5.030794e+09 |
44 | -2.766259 | 7.652191 | -21.167945 | 58.556028 | -161.981161 | 448.081909 | -1.239511e+03 | 3.428808e+03 | -9.484973e+03 | 2.623790e+04 |
37 | -2.738399 | 7.498829 | -20.534785 | 56.232432 | -153.986833 | 421.677380 | -1.154721e+03 | 3.162086e+03 | -8.659054e+03 | 2.371194e+04 |
5 | -0.578506 | 0.334669 | -0.193608 | 0.112004 | -0.064795 | 0.037484 | -2.168480e-02 | 1.254479e-02 | -7.257238e-03 | 4.198356e-03 |
32 | -3.493023 | 12.201207 | -42.619094 | 148.869463 | -520.004411 | 1816.387202 | -6.344682e+03 | 2.216212e+04 | -7.741278e+04 | 2.704046e+05 |
27 | -2.965964 | 8.796943 | -26.091415 | 77.386199 | -229.524684 | 680.761955 | -2.019115e+03 | 5.988624e+03 | -1.776204e+04 | 5.268158e+04 |
46 | 4.825670 | 23.287096 | 112.375850 | 542.288824 | 2616.907175 | 12628.331726 | 6.094017e+04 | 2.940772e+05 | 1.419120e+06 | 6.848203e+06 |
12 | 3.660112 | 13.396416 | 49.032378 | 179.463974 | 656.858162 | 2404.174135 | 8.799545e+03 | 3.220732e+04 | 1.178824e+05 | 4.314626e+05 |
2 | -4.478046 | 20.052899 | -89.797810 | 402.118751 | -1800.706392 | 8063.646628 | -3.610938e+04 | 1.616995e+05 | -7.240978e+05 | 3.242544e+06 |
8 | 0.108538 | 0.011781 | 0.001279 | 0.000139 | 0.000015 | 0.000002 | 1.774525e-07 | 1.926040e-08 | 2.090492e-09 | 2.268986e-10 |
36 | 3.727977 | 13.897811 | 51.810717 | 193.149152 | 720.055560 | 2684.350427 | 1.000720e+04 | 3.730660e+04 | 1.390781e+05 | 5.184800e+05 |
4 | 4.850593 | 23.528255 | 114.125996 | 553.578791 | 2685.185564 | 13024.743051 | 6.317773e+04 | 3.064495e+05 | 1.486462e+06 | 7.210222e+06 |
33 | -8.747075 | 76.511319 | -669.250243 | 5853.982010 | -51205.219163 | 447895.887816 | -3.917779e+06 | 3.426911e+07 | -2.997544e+08 | 2.621974e+09 |
26 | 6.520774 | 42.520497 | 277.266559 | 1807.992633 | 11789.511780 | 76876.744660 | 5.012959e+05 | 3.268837e+06 | 2.131535e+07 | 1.389926e+08 |
Performing polynomial regression for each degree¶
from sklearn.linear_model import LinearRegression
The next cell isn’t fitting anything yet; it just shows how sub_cols changes on each pass through the loop.
for i in range(1, max_deg+1):
    sub_cols = cols[:i]
    reg = LinearRegression()
    print(sub_cols)
['x1']
['x1', 'x2']
['x1', 'x2', 'x3']
['x1', 'x2', 'x3', 'x4']
['x1', 'x2', 'x3', 'x4', 'x5']
['x1', 'x2', 'x3', 'x4', 'x5', 'x6']
['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7']
['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8']
['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']
['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10']
X_train
| | x1 | x2 | x3 | x4 | x5 | x6 | x7 | x8 | x9 | x10 |
|---|---|---|---|---|---|---|---|---|---|---|
13 | -8.982275 | 80.681271 | -724.701394 | 6509.467488 | -58469.829551 | 525192.110405 | -4.717420e+06 | 4.237317e+07 | -3.806075e+08 | 3.418721e+09 |
11 | 6.449735 | 41.599085 | 268.303082 | 1730.483838 | 11161.162581 | 71986.543546 | 4.642941e+05 | 2.994574e+06 | 1.931421e+07 | 1.245716e+08 |
1 | 6.465018 | 41.796463 | 270.214901 | 1746.944309 | 11294.027098 | 73016.092970 | 4.720504e+05 | 3.051814e+06 | 1.973004e+07 | 1.275550e+08 |
25 | 0.371593 | 0.138081 | 0.051310 | 0.019066 | 0.007085 | 0.002633 | 9.782985e-04 | 3.635287e-04 | 1.350847e-04 | 5.019650e-05 |
16 | -3.942982 | 15.547104 | -61.301947 | 241.712453 | -953.067768 | 3757.928725 | -1.481744e+04 | 5.842491e+04 | -2.303683e+05 | 9.083382e+05 |
45 | -4.457650 | 19.870645 | -88.576385 | 394.842539 | -1760.069919 | 7845.776000 | -3.497372e+04 | 1.559006e+05 | -6.949505e+05 | 3.097846e+06 |
15 | 5.481667 | 30.048674 | 164.716823 | 902.922780 | 4949.522041 | 27131.631828 | 1.487266e+05 | 8.152695e+05 | 4.469036e+06 | 2.449777e+07 |
42 | -9.804880 | 96.135668 | -942.598671 | 9242.066675 | -90617.352934 | 888492.254125 | -8.711560e+06 | 8.541580e+07 | -8.374916e+08 | 8.211505e+09 |
20 | -0.690041 | 0.476157 | -0.328568 | 0.226726 | -0.156450 | 0.107957 | -7.449486e-02 | 5.140454e-02 | -3.547126e-02 | 2.447664e-02 |
35 | 4.852230 | 23.544134 | 114.241552 | 554.326269 | 2689.718466 | 13051.132219 | 6.332709e+04 | 3.072776e+05 | 1.490982e+06 | 7.234585e+06 |
9 | 4.556122 | 20.758251 | 94.577131 | 430.904978 | 1963.255801 | 8944.833631 | 4.075376e+04 | 1.856791e+05 | 8.459767e+05 | 3.854373e+06 |
10 | -6.299257 | 39.680640 | -249.958554 | 1574.553197 | -9918.515413 | 62479.278687 | -3.935730e+05 | 2.479218e+06 | -1.561723e+07 | 9.837695e+07 |
Notice how y_train holds the exact same rows as X_train (look at the index of each).
y_train
13 2.574974
11 130.940520
1 73.954907
25 1.134863
16 -35.901888
45 -26.760377
15 24.441828
42 131.476240
20 61.591699
35 -13.988943
9 1.089803
10 -7.251101
Name: y, dtype: float64
Evaluating the performance, Part 1¶
from sklearn.metrics import mean_squared_error
mse_train_dict = {}
for i in range(1, max_deg+1):
    sub_cols = cols[:i]
    reg = LinearRegression()
    reg.fit(X_train[sub_cols], y_train)
    # what if we predict also on the training set?
    mse_train_dict[i] = mean_squared_error(reg.predict(X_train[sub_cols]), y_train)
When you look at the following dictionary, the mean squared error decreases as the degree increases, so it looks like the model keeps improving.

But there is a major problem: we both fit and predicted on the training set. Especially for higher-degree polynomials, there is a big risk that we are overfitting the data.
mse_train_dict
{1: 3006.2441403930584,
2: 1845.1659476797647,
3: 1820.054196587029,
4: 590.3862081005426,
5: 589.0984843939519,
6: 523.8762601493494,
7: 373.279496780059,
8: 206.85098887525842,
9: 204.99331332780278,
10: 101.47175254241118}
Evaluating the performance, Part 2¶
To get meaningful values of MSE (Mean Squared Error), we should always evaluate on a test set that was not used during training (fitting).
mse_test_dict = {}
for i in range(1, max_deg+1):
    sub_cols = cols[:i]
    reg = LinearRegression()
    reg.fit(X_train[sub_cols], y_train)
    # store predictions on the full dataset; we will use these columns for plotting below
    df[f"Pred{i}"] = reg.predict(df[sub_cols])
    mse_test_dict[i] = mean_squared_error(reg.predict(X_test[sub_cols]), y_test)
Notice how degree 2 gives the best (lowest) test MSE, unlike in the previous dictionary, where degree 10 looked best. This reflects the fact that the true underlying function has degree 2. (There is some randomness in this procedure: when I tried other values like random_state=4, degree 2 was not always the best; sometimes degree 3 or degree 4 appeared best.)
mse_test_dict
{1: 4960.848889431734,
2: 3724.5263419375206,
3: 3742.5488620098768,
4: 6252.874102524025,
5: 6253.412384993094,
6: 7016.575636378879,
7: 5609.641936177731,
8: 9481.527365671222,
9: 12142.080457403497,
10: 103395.24260458215}
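To see this randomness more systematically, one option (a sketch, not something we ran in class) is to repeat the whole procedure for several different random_state values and record which degree achieves the lowest test MSE each time:

# Sketch: for each seed, redo the split, fit every degree, and record which
# degree gives the lowest test MSE. The specific seeds are arbitrary.
best_degree = {}
for seed in range(10):
    Xtr, Xte, ytr, yte = train_test_split(
        df[cols], df["y"], train_size=0.24, random_state=seed
    )
    errs = {}
    for i in range(1, max_deg + 1):
        sub_cols = cols[:i]
        reg = LinearRegression()
        reg.fit(Xtr[sub_cols], ytr)
        errs[i] = mean_squared_error(yte, reg.predict(Xte[sub_cols]))
    best_degree[seed] = min(errs, key=errs.get)

best_degree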
Plotting the polynomial fits¶
This code is similar to what we did on Friday of last week, but first we’re going to add a column called “In_train” recording whether each point is in the training set or the test set.
df["In_train"] = "test"
X_train.index
Int64Index([13, 11, 1, 25, 16, 45, 15, 42, 20, 35, 9, 10], dtype='int64')
df.loc[X_train.index, "In_train"] = "train"
df.head()
| | x | y_true | y | x1 | x2 | x3 | x4 | x5 | x6 | x7 | ... | Pred2 | Pred3 | Pred4 | Pred5 | Pred6 | Pred7 | Pred8 | Pred9 | Pred10 | In_train |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | -3.329208 | -18.207589 | -117.484900 | -3.329208 | 11.083626 | -36.899694 | 122.846756 | -408.982395 | 1361.587441 | -4533.007730 | ... | -16.647390 | -19.954037 | 6.715567 | 5.215954 | -7.995404 | -18.576185 | 3.146424 | -5.566199 | -15.676053 | test |
1 | 6.465018 | 74.160562 | 73.954907 | 6.465018 | 41.796463 | 270.214901 | 1746.944309 | 11294.027098 | 73016.092970 | 472050.384357 | ... | 70.916916 | 65.455116 | 101.808679 | 102.350612 | 105.135066 | 100.392589 | 103.474571 | 103.854390 | 97.078223 | train |
2 | -4.478046 | -7.670062 | -13.810089 | -4.478046 | 20.052899 | -89.797810 | 402.118751 | -1800.706392 | 8063.646628 | -36109.383086 | ... | -9.287827 | -14.954228 | -16.790957 | -18.030386 | -17.560873 | -25.152785 | -38.972803 | -35.768769 | -25.492479 | test |
3 | 2.043272 | -7.925152 | 19.461182 | 2.043272 | 4.174960 | 8.530580 | 17.430295 | 35.614834 | 72.770792 | 148.690523 | ... | -1.930187 | 6.789503 | -5.399120 | -3.567399 | 30.239546 | -11.206589 | -73.991439 | -94.222921 | 94.585113 | test |
4 | 4.850593 | 36.485466 | 22.375230 | 4.850593 | 23.528255 | 114.125996 | 553.578791 | 2685.185564 | 13024.743051 | 63177.731115 | ... | 37.963743 | 42.179199 | -1.465158 | -2.100513 | -2.486611 | -5.545758 | 0.475225 | 2.001401 | -21.612846 | test |
5 rows × 24 columns
c = alt.Chart(df).mark_circle().encode(
    x="x",
    y="y",
    color="In_train"
)

c_true = alt.Chart(df).mark_line(color="black").encode(
    x="x",
    y="y_true",
)

chart_list = []

# one chart per degree; clip=True and the fixed y-domain keep the most extreme
# predictions from stretching the axes
for i in range(1, max_deg+1):
    c_temp = alt.Chart(df).mark_line(color="red", clip=True).encode(
        x="x",
        y=alt.Y(f"Pred{i}", scale=alt.Scale(domain=(-100, 300))),
    )
    chart_list.append(c_temp)
all_charts = [c+c_true+d for d in chart_list]
Both the training error and the test error are high when the degree is 1. That corresponds to underfitting; in the pictures below, it corresponds to the straight line not being able to fit the data well.

The training error is low but the test error is high for large degrees (like degree 8). This corresponds to overfitting. Look at the 8th plot below, and notice how the red degree-8 polynomial follows the training data (the orange points) very closely.

Notice how the fitted polynomial stays close to the true secret polynomial (the black quadratic curve) when we use degree 2 or degree 3.
alt.vconcat(*all_charts)
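One way to see underfitting and overfitting in a single picture (a sketch, not from the lecture) is to plot the training MSE and the test MSE against the degree; a log scale keeps the huge degree-10 test error from dominating the chart.

# Combine the two MSE dictionaries into one long-format DataFrame for Altair.
df_mse = pd.DataFrame({
    "degree": list(mse_train_dict.keys()),
    "train": list(mse_train_dict.values()),
    "test": list(mse_test_dict.values()),
}).melt(id_vars="degree", var_name="set", value_name="mse")

alt.Chart(df_mse).mark_line(point=True).encode(
    x="degree:O",
    y=alt.Y("mse", scale=alt.Scale(type="log")),
    color="set",
)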
If we wanted the polynomials to look smoother, we could replace df with something like the df_plot DataFrame that we used on Friday of last week.
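df_plot isn’t defined in this notebook, so here is a minimal sketch of what it could look like, assuming it is just a dense grid of x values with the same power columns (the actual df_plot from last week may have been built differently):

import numpy as np

# Hypothetical reconstruction of df_plot: a fine grid of x values plus the power columns.
df_plot = pd.DataFrame({"x": np.linspace(df["x"].min(), df["x"].max(), 200)})
for i in range(1, max_deg + 1):
    df_plot[f"x{i}"] = df_plot["x"]**i

# Recompute the predictions on the grid so the plotted curves look smooth.
for i in range(1, max_deg + 1):
    sub_cols = cols[:i]
    reg = LinearRegression()
    reg.fit(X_train[sub_cols], y_train)
    df_plot[f"Pred{i}"] = reg.predict(df_plot[sub_cols])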