{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"cell_id": "e4e4c69a-c97b-49fb-8c83-2a5b1d3471da",
"deepnote_cell_height": 224,
"deepnote_cell_type": "markdown",
"tags": []
},
"source": [
"# Polynomial Regression 2\n",
"\n",
"I think the most important concept in machine learning is **overfitting**. The idea is that if your model is too flexible (relative to the number of data points), then it can match the training data very closely but still do a poor job of predicting new values.\n",
"\n",
"When performing polynomial regression, the higher the degree of the polynomial, the more flexible the model. So the higher the degree, the greater the risk of overfitting."
]
},
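{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make *too flexible relative to the number of data points* concrete: a polynomial of degree $d$ has $d+1$ coefficients, so a degree 9 polynomial can pass exactly through any 10 points with distinct $x$-values, however noisy they are. The following is a small sketch on made-up data (the names `x_toy` and `y_toy` are only for illustration), using NumPy's `np.polyfit` and `np.polyval`. A perfect fit to these 10 points says nothing about how well the curve predicts new points."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Made-up data: 10 points that roughly follow a line, plus random noise\n",
"rng = np.random.default_rng(0)\n",
"x_toy = np.linspace(0, 5, 10)\n",
"y_toy = 2 * x_toy + 3 + rng.normal(scale=1.0, size=10)\n",
"\n",
"# Degree 9: 10 coefficients, enough to interpolate all 10 points\n",
"coefs9 = np.polyfit(x_toy, y_toy, deg=9)\n",
"resid9 = np.max(np.abs(np.polyval(coefs9, x_toy) - y_toy))\n",
"\n",
"# Degree 1: a straight line cannot pass through all of the noisy points\n",
"coefs1 = np.polyfit(x_toy, y_toy, deg=1)\n",
"resid1 = np.max(np.abs(np.polyval(coefs1, x_toy) - y_toy))\n",
"\n",
"print(resid9)  # should be very close to 0\n",
"print(resid1)  # noticeably larger than 0"
]
},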
{
"cell_type": "markdown",
"metadata": {
"cell_id": "2f6571ef76ce4b94b0d895a99a3fa62c",
"deepnote_cell_height": 70,
"deepnote_cell_type": "markdown",
"tags": []
},
"source": [
"## The taxis dataset\n",
"\n",
"The taxis dataset (available through `seaborn`) contains information about over 6000 taxi rides in New York City. Our goal is to model the total cost of a taxi ride using its distance."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cell_id": "90b5eedfa9f141a1bd52062da5e353cb",
"deepnote_cell_height": 135,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 5065,
"execution_start": 1651694170552,
"source_hash": "64a83654",
"tags": []
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import altair as alt\n",
"alt.data_transformers.enable('default', max_rows=10000)\n",
"import seaborn as sns"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"cell_id": "4a2afe954c8f42d6a3ceb8c0462b216b",
"deepnote_cell_height": 81,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 431,
"execution_start": 1651694181207,
"source_hash": "eaa10c21",
"tags": []
},
"outputs": [],
"source": [
"df = sns.load_dataset(\"taxis\").dropna()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"cell_id": "220efa9a0ebe4327b3a1b13209447276",
"deepnote_cell_height": 395,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 134,
"execution_start": 1651694539556,
"source_hash": "c085b6ba",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
" pickup dropoff passengers distance fare tip \\\n",
"0 2019-03-23 20:21:09 2019-03-23 20:27:24 1 1.60 7.0 2.15 \n",
"1 2019-03-04 16:11:55 2019-03-04 16:19:00 1 0.79 5.0 0.00 \n",
"2 2019-03-27 17:53:01 2019-03-27 18:00:25 1 1.37 7.5 2.36 \n",
"3 2019-03-10 01:23:59 2019-03-10 01:49:51 1 7.70 27.0 6.15 \n",
"4 2019-03-30 13:27:42 2019-03-30 13:37:14 3 2.16 9.0 1.10 \n",
"\n",
" tolls total color payment pickup_zone \\\n",
"0 0.0 12.95 yellow credit card Lenox Hill West \n",
"1 0.0 9.30 yellow cash Upper West Side South \n",
"2 0.0 14.16 yellow credit card Alphabet City \n",
"3 0.0 36.95 yellow credit card Hudson Sq \n",
"4 0.0 13.40 yellow credit card Midtown East \n",
"\n",
" dropoff_zone pickup_borough dropoff_borough \n",
"0 UN/Turtle Bay South Manhattan Manhattan \n",
"1 Upper West Side South Manhattan Manhattan \n",
"2 West Village Manhattan Manhattan \n",
"3 Yorkville West Manhattan Manhattan \n",
"4 Yorkville West Manhattan Manhattan "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following plot shows that the total cost follows a linear function of the distance quite closely. Instead of a linear model, we are going to try polynomial regression with a high degree (degree 9 in this case)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"cell_id": "94eff324ca1448cb9b9a5cb5cbf6396f",
"deepnote_cell_height": 512,
"deepnote_cell_type": "code",
"deepnote_output_heights": [
361
],
"deepnote_to_be_reexecuted": false,
"execution_millis": 1017,
"execution_start": 1651694627624,
"source_hash": "ede7be0d",
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
""
],
"text/plain": [
"alt.Chart(...)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"alt.Chart(df).mark_circle().encode(\n",
" x=\"distance\",\n",
" y=\"total\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we fit a degree 9 polynomial to all 6000+ data points, the result would probably look quite reasonable, because there is so much data. Instead, we will use only 40 of the data points. In general, the fewer data points you use, the greater the risk of overfitting."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"cell_id": "ce9fa8e7091f45a98aec0c15f1d58d5c",
"deepnote_cell_height": 81,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 6,
"execution_start": 1651694778361,
"source_hash": "746a4dbc",
"tags": []
},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following is *not* the usual way to call `train_test_split`. Below we will use the usual way, which involves list unpacking.\n",
"\n",
"Here `train_size=40` says we want to choose 40 random rows from the DataFrame `df` for the training set; the remaining rows form the test set. If we had instead used `train_size=0.4`, that would say we want to choose 40% of the rows."
]
},
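{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sketch of how `train_size` is interpreted, on a made-up DataFrame `toy` with 200 rows (not part of the taxis data): an integer is an exact number of rows, while a float between 0 and 1 is a proportion of the rows. In both cases, the rows not chosen for the training set form the test set."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Made-up DataFrame with 200 rows\n",
"toy = pd.DataFrame({'x': range(200)})\n",
"\n",
"# Integer train_size: exactly 40 rows in the training set\n",
"train_int, test_int = train_test_split(toy, train_size=40)\n",
"print(train_int.shape, test_int.shape)  # (40, 1) (160, 1)\n",
"\n",
"# Float train_size: 40% of the 200 rows, so 80 rows\n",
"train_frac, test_frac = train_test_split(toy, train_size=0.4)\n",
"print(train_frac.shape, test_frac.shape)  # (80, 1) (120, 1)"
]
},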
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"cell_id": "2c9de1ad14144e8aa06eccb2063508a2",
"deepnote_cell_height": 99,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 6,
"execution_start": 1651694850545,
"source_hash": "eb474865",
"tags": []
},
"outputs": [],
"source": [
"# use 40 data points\n",
"a = train_test_split(df, train_size=40)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"cell_id": "9e8eb4fae0e94e5aa4bbba7cbb6489e1",
"deepnote_cell_height": 118.1875,
"deepnote_cell_type": "code",
"deepnote_output_heights": [
21.1875
],
"deepnote_to_be_reexecuted": false,
"execution_millis": 4,
"execution_start": 1651694865156,
"source_hash": "1478ef",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"list"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(a)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"cell_id": "a3d6df3a7730438fa11669f92061a507",
"deepnote_cell_height": 118.1875,
"deepnote_cell_type": "code",
"deepnote_output_heights": [
21.1875
],
"deepnote_to_be_reexecuted": false,
"execution_millis": 11,
"execution_start": 1651694877667,
"source_hash": "f64cd242",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(a)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"cell_id": "fa81cbe73a274b7589e49352de29782a",
"deepnote_cell_height": 118.1875,
"deepnote_cell_type": "code",
"deepnote_output_heights": [
21.1875
],
"deepnote_to_be_reexecuted": false,
"execution_millis": 6,
"execution_start": 1651694907861,
"source_hash": "704ff305",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"pandas.core.frame.DataFrame"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(a[0])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"cell_id": "d8713820e1794d1d8af5e6ac524146e9",
"deepnote_cell_height": 118.1875,
"deepnote_cell_type": "code",
"deepnote_output_heights": [
21.1875
],
"deepnote_to_be_reexecuted": false,
"execution_millis": 9,
"execution_start": 1651694918308,
"source_hash": "2abccaa2",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"(40, 14)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a[0].shape"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"cell_id": "8ffd16afd5f04e30a756acebe7fbb678",
"deepnote_cell_height": 118.1875,
"deepnote_cell_type": "code",
"deepnote_output_heights": [
21.1875
],
"deepnote_to_be_reexecuted": false,
"execution_millis": 17,
"execution_start": 1651694930593,
"source_hash": "bdee13fc",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"(6301, 14)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a[1].shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Currently `a[0]` contains 40 rows from the DataFrame, and `a[1]` contains the remaining 6301 rows."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"cell_id": "9141f9534a08453fb2275e0f7debb008",
"deepnote_cell_height": 118.1875,
"deepnote_cell_type": "code",
"deepnote_output_heights": [
21.1875
],
"deepnote_to_be_reexecuted": false,
"execution_millis": 11,
"execution_start": 1651694935057,
"source_hash": "14f60b8f",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"(6341, 14)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The rows are chosen at random, and they are not kept in their original order, since `train_test_split` shuffles the data by default."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"cell_id": "67e4e225eb8246a2b27d765d5c4fabe4",
"deepnote_cell_height": 395,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 62,
"execution_start": 1651694995970,
"source_hash": "b62c9399",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
" pickup dropoff passengers distance fare \\\n",
"1855 2019-03-29 14:20:18 2019-03-29 14:36:35 6 2.72 12.5 \n",
"5029 2019-03-21 06:51:40 2019-03-21 07:07:11 1 3.39 14.0 \n",
"5277 2019-03-19 04:54:04 2019-03-19 05:00:05 2 1.65 7.5 \n",
"247 2019-03-07 18:49:41 2019-03-07 19:07:05 1 2.75 13.0 \n",
"912 2019-03-15 23:13:44 2019-03-15 23:24:47 1 1.70 9.5 \n",
"\n",
" tip tolls total color payment pickup_zone \\\n",
"1855 3.16 0.0 18.96 yellow credit card Hudson Sq \n",
"5029 2.75 0.0 20.05 yellow credit card Central Park \n",
"5277 2.00 0.0 13.30 yellow credit card East Chelsea \n",
"247 3.46 0.0 20.76 yellow credit card Yorkville West \n",
"912 2.65 0.0 15.95 yellow credit card Lincoln Square West \n",
"\n",
" dropoff_zone pickup_borough dropoff_borough \n",
"1855 Murray Hill Manhattan Manhattan \n",
"5029 Times Sq/Theatre District Manhattan Manhattan \n",
"5277 Clinton East Manhattan Manhattan \n",
"247 Lincoln Square East Manhattan Manhattan \n",
"912 East Chelsea Manhattan Manhattan "
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a[0].head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The usual way to call `train_test_split` is to use list unpacking, which is what we do in the following code."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"cell_id": "2c3d826550ed4a828eb7476c1458e8aa",
"deepnote_cell_height": 99,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 6,
"execution_start": 1651695117640,
"source_hash": "cf8362a5",
"tags": []
},
"outputs": [],
"source": [
"# list unpacking\n",
"df_train, df_test = train_test_split(df, train_size=40)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each call to `train_test_split` makes a new random choice, so `df_train` will (almost certainly) contain a different set of 40 rows from `a[0]`."
]
},
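{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want the same split every time the notebook runs (for example, so that the plots below do not change), you can pass an integer `random_state` to `train_test_split`. Here is a small sketch on a made-up 10-row DataFrame `toy`; it also confirms that list unpacking gives exactly the same two DataFrames as indexing into the returned list."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Made-up DataFrame with 10 rows\n",
"toy = pd.DataFrame({'x': range(10)})\n",
"\n",
"# With the same random_state, the same rows are chosen each time\n",
"toy_split = train_test_split(toy, train_size=4, random_state=0)\n",
"toy_train, toy_test = train_test_split(toy, train_size=4, random_state=0)\n",
"\n",
"# Indexing into the list and list unpacking give the same DataFrames\n",
"print(toy_split[0].equals(toy_train))  # True\n",
"print(toy_split[1].equals(toy_test))   # True"
]
},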
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"cell_id": "a1432e5d1bb94274a5d02f0c1e47d660",
"deepnote_cell_height": 395,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 60,
"execution_start": 1651695119730,
"source_hash": "5e2a9c26",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
" pickup dropoff passengers distance fare \\\n",
"1091 2019-03-05 20:17:31 2019-03-05 20:35:21 1 4.20 16.0 \n",
"5459 2019-03-29 16:25:22 2019-03-29 17:15:57 1 12.76 39.5 \n",
"4164 2019-03-20 09:11:06 2019-03-20 09:21:44 2 0.83 8.0 \n",
"819 2019-03-17 19:30:19 2019-03-17 19:41:39 1 2.10 9.5 \n",
"6136 2019-03-05 19:01:21 2019-03-05 19:05:56 1 0.87 5.0 \n",
"\n",
" tip tolls total color payment pickup_zone \\\n",
"1091 0.0 0.0 19.8 yellow cash Penn Station/Madison Sq West \n",
"5459 0.0 0.0 41.3 green credit card Brighton Beach \n",
"4164 0.0 0.0 11.3 yellow cash Midtown South \n",
"819 0.0 0.0 12.8 yellow cash Lincoln Square East \n",
"6136 0.0 0.0 6.8 green cash Elmhurst \n",
"\n",
" dropoff_zone pickup_borough dropoff_borough \n",
"1091 Battery Park City Manhattan Manhattan \n",
"5459 Richmond Hill Brooklyn Queens \n",
"4164 Midtown East Manhattan Manhattan \n",
"819 Upper West Side North Manhattan Manhattan \n",
"6136 Elmhurst Queens Queens "
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_train.head()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"cell_id": "c0406a4f79af4bd3a80d328ce4c5f9ff",
"deepnote_cell_height": 118.1875,
"deepnote_cell_type": "code",
"deepnote_output_heights": [
21.1875
],
"deepnote_to_be_reexecuted": false,
"execution_millis": 170,
"execution_start": 1651695196957,
"source_hash": "db981d08",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"(40, 14)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_train.shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"cell_id": "fdbe95819f9646e8bb067705e18c28e5",
"deepnote_cell_height": 70,
"deepnote_cell_type": "markdown",
"tags": []
},
"source": [
"## Fitting polynomial regression"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"cell_id": "65ef0388c69a4933b65e75ba6936d998",
"deepnote_cell_height": 135,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 16,
"execution_start": 1651695572487,
"source_hash": "a285ad9",
"tags": []
},
"outputs": [],
"source": [
"cols = []\n",
"for deg in range(1,10):\n",
" cols.append(f\"d{deg}\")\n",
" df_train[f\"d{deg}\"] = df_train[\"distance\"]**deg"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a more DRY (\"Don't Repeat Yourself\") approach, where we only have to type `f\"d{deg}\"` once."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"cell_id": "ec9358bd85a2452f9736696123a43ba8",
"deepnote_cell_height": 153,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 6,
"execution_start": 1651695621654,
"source_hash": "6dc14bfc",
"tags": []
},
"outputs": [],
"source": [
"cols = []\n",
"for deg in range(1,10):\n",
" c = f\"d{deg}\"\n",
" cols.append(c)\n",
" df_train[c] = df_train[\"distance\"]**deg"
]
},
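{
"cell_type": "markdown",
"metadata": {},
"source": [
"An alternative to the loop, sketched here on a made-up DataFrame `toy` standing in for `df_train`: build all of the power columns at once with a dictionary comprehension passed to `DataFrame.assign`, which returns a new DataFrame rather than modifying the existing one. scikit-learn's `PolynomialFeatures` (with `include_bias=False`) produces the same powers as a NumPy array. The names starting with `toy` are only for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"# Made-up DataFrame standing in for df_train\n",
"toy = pd.DataFrame({'distance': [0.5, 1.0, 2.0]})\n",
"\n",
"toy_cols = [f'd{deg}' for deg in range(1, 10)]\n",
"# assign returns a new DataFrame with the extra columns d1, ..., d9\n",
"toy_powers = toy.assign(**{f'd{deg}': toy['distance']**deg for deg in range(1, 10)})\n",
"print(toy_powers.columns.tolist())\n",
"\n",
"# PolynomialFeatures computes x, x^2, ..., x^9 (no constant column)\n",
"toy_poly = PolynomialFeatures(degree=9, include_bias=False)\n",
"toy_arr = toy_poly.fit_transform(toy[['distance']])\n",
"print(np.allclose(toy_arr, toy_powers[toy_cols]))  # True"
]
},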
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"cell_id": "37924b8fa2e141abb05c0f0fee4bb327",
"deepnote_cell_height": 118.1875,
"deepnote_cell_type": "code",
"deepnote_output_heights": [
21.1875
],
"deepnote_to_be_reexecuted": false,
"execution_millis": 7,
"execution_start": 1651695624088,
"source_hash": "cb73c151",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"['d1', 'd2', 'd3', 'd4', 'd5', 'd6', 'd7', 'd8', 'd9']"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cols"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that 9 columns have been added to `df_train`. They contain the 1st, 2nd, ..., 9th powers of the values in the \"distance\" column."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"cell_id": "03a6321f99094919b5be5bda6e321a34",
"deepnote_cell_height": 600,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 179,
"execution_start": 1651695444130,
"source_hash": "e7a4a97c",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
" pickup dropoff passengers distance fare \\\n",
"1091 2019-03-05 20:17:31 2019-03-05 20:35:21 1 4.20 16.0 \n",
"5459 2019-03-29 16:25:22 2019-03-29 17:15:57 1 12.76 39.5 \n",
"4164 2019-03-20 09:11:06 2019-03-20 09:21:44 2 0.83 8.0 \n",
"819 2019-03-17 19:30:19 2019-03-17 19:41:39 1 2.10 9.5 \n",
"6136 2019-03-05 19:01:21 2019-03-05 19:05:56 1 0.87 5.0 \n",
"\n",
" tip tolls total color payment ... dropoff_borough d1 \\\n",
"1091 0.0 0.0 19.8 yellow cash ... Manhattan 4.20 \n",
"5459 0.0 0.0 41.3 green credit card ... Queens 12.76 \n",
"4164 0.0 0.0 11.3 yellow cash ... Manhattan 0.83 \n",
"819 0.0 0.0 12.8 yellow cash ... Manhattan 2.10 \n",
"6136 0.0 0.0 6.8 green cash ... Queens 0.87 \n",
"\n",
" d2 d3 d4 d5 d6 \\\n",
"1091 17.6400 74.088000 311.169600 1306.912320 5.489032e+03 \n",
"5459 162.8176 2077.552576 26509.570870 338262.124298 4.316225e+06 \n",
"4164 0.6889 0.571787 0.474583 0.393904 3.269404e-01 \n",
"819 4.4100 9.261000 19.448100 40.841010 8.576612e+01 \n",
"6136 0.7569 0.658503 0.572898 0.498421 4.336262e-01 \n",
"\n",
" d7 d8 d9 \n",
"1091 2.305393e+04 9.682652e+04 4.066714e+05 \n",
"5459 5.507503e+07 7.027573e+08 8.967184e+09 \n",
"4164 2.713605e-01 2.252292e-01 1.869403e-01 \n",
"819 1.801089e+02 3.782286e+02 7.942800e+02 \n",
"6136 3.772548e-01 3.282117e-01 2.855442e-01 \n",
"\n",
"[5 rows x 23 columns]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_train.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
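"Remember, as we saw last time, polynomial regression can be viewed as a special case of linear regression: a degree 9 polynomial in `distance` is a linear function of the columns `d1`, ..., `d9`. (The reverse is also true and more obvious: linear regression with one input variable is polynomial regression of degree 1.)"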
]
},
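{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to check this numerically, sketched on made-up data: fit `LinearRegression` on the columns $x$, $x^2$, $x^3$ and compare the result with NumPy's `np.polyfit` of degree 3 on the same points. Both solve the same least-squares problem, so the coefficients should agree up to round-off. (Degree 3 is used here instead of 9 to keep the powers small and the computation well conditioned; the names ending in `_demo` are only for illustration.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.linear_model import LinearRegression\n",
"\n",
"# Made-up data from a cubic, plus noise\n",
"rng = np.random.default_rng(1)\n",
"x_demo = rng.uniform(0, 5, size=30)\n",
"y_demo = 1 + 2*x_demo - 0.5*x_demo**2 + 0.1*x_demo**3 + rng.normal(scale=0.3, size=30)\n",
"\n",
"# Polynomial regression as linear regression on the powers of x\n",
"X_demo = pd.DataFrame({f'd{deg}': x_demo**deg for deg in range(1, 4)})\n",
"lin_demo = LinearRegression().fit(X_demo, y_demo)\n",
"\n",
"# np.polyfit lists coefficients from the highest power down to the constant,\n",
"# so reverse it to get [constant, x, x^2, x^3]\n",
"poly_demo = np.polyfit(x_demo, y_demo, deg=3)[::-1]\n",
"\n",
"print(lin_demo.intercept_, lin_demo.coef_)\n",
"print(poly_demo)  # should match the line above up to round-off"
]
},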
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"cell_id": "18595bda00f4433d9c2b54c31c6df588",
"deepnote_cell_height": 81,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 0,
"execution_start": 1651695714609,
"source_hash": "9527aab5",
"tags": []
},
"outputs": [],
"source": [
"from sklearn.linear_model import LinearRegression"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"cell_id": "0425f7bee2e6480ca26be1b0d37ad389",
"deepnote_cell_height": 81,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 1,
"execution_start": 1651695731464,
"source_hash": "cd72858e",
"tags": []
},
"outputs": [],
"source": [
"reg = LinearRegression()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It is important to fit only on `df_train`: we want to use only 40 data points, because that makes the overfitting easier to see."
]
},
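{
"cell_type": "markdown",
"metadata": {},
"source": [
"The reason to keep the other rows aside is that they let us measure how well a model predicts data it was not fit on. Here is a small sketch of that comparison on made-up data (not the taxis data), using scikit-learn's `mean_squared_error`. Because the degree 9 model can do everything the degree 1 model can, its training error cannot be larger (up to round-off); its test error, on the other hand, is typically larger, though the exact numbers depend on the random seed. The names starting with `demo` are only for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Made-up data where total really is a linear function of distance, plus noise\n",
"# (distances between 0 and 2 keep the 9th powers small)\n",
"rng = np.random.default_rng(2)\n",
"demo = pd.DataFrame({'distance': rng.uniform(0, 2, size=1000)})\n",
"demo['total'] = 3 + 2.5 * demo['distance'] + rng.normal(scale=0.5, size=1000)\n",
"\n",
"demo_train, demo_test = train_test_split(demo, train_size=40, random_state=0)\n",
"# Work on copies so the power columns are added to independent DataFrames\n",
"demo_train = demo_train.copy()\n",
"demo_test = demo_test.copy()\n",
"for deg in range(1, 10):\n",
"    demo_train[f'd{deg}'] = demo_train['distance']**deg\n",
"    demo_test[f'd{deg}'] = demo_test['distance']**deg\n",
"\n",
"demo_errors = {}\n",
"for max_deg in [1, 9]:\n",
"    demo_feats = [f'd{deg}' for deg in range(1, max_deg + 1)]\n",
"    demo_model = LinearRegression().fit(demo_train[demo_feats], demo_train['total'])\n",
"    train_mse = mean_squared_error(demo_train['total'], demo_model.predict(demo_train[demo_feats]))\n",
"    test_mse = mean_squared_error(demo_test['total'], demo_model.predict(demo_test[demo_feats]))\n",
"    demo_errors[max_deg] = (train_mse, test_mse)\n",
"    print(f'degree {max_deg}: train MSE {train_mse:.3f}, test MSE {test_mse:.3f}')"
]
},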
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"cell_id": "67ac2d89e0304659b4c6ea783d14502e",
"deepnote_cell_height": 118.1875,
"deepnote_cell_type": "code",
"deepnote_output_heights": [
21.1875
],
"deepnote_to_be_reexecuted": false,
"execution_millis": 13,
"execution_start": 1651695879050,
"source_hash": "40287825",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"LinearRegression()"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reg.fit(df_train[cols], df_train[\"total\"])"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"cell_id": "c188ddcdc9074691a8521537fc2fdad8",
"deepnote_cell_height": 81,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 0,
"execution_start": 1651695926138,
"source_hash": "97a9514a",
"tags": []
},
"outputs": [],
"source": [
"df_train[\"Pred\"] = reg.predict(df_train[cols])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The red line in the following chart shows the fitted polynomial. It doesn't look much like a polynomial, because Altair (like Matlab and Matplotlib) connects the plotted points with straight line segments, and we only have 40 points. Below we will make a DataFrame with many more input values, so that the polynomial will look smoother."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"cell_id": "ca8ca92edd8b48f5951213896094eff8",
"deepnote_cell_height": 633,
"deepnote_cell_type": "code",
"deepnote_output_heights": [
356
],
"deepnote_to_be_reexecuted": false,
"execution_millis": 103,
"execution_start": 1651696050196,
"source_hash": "d9da7148",
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
""
],
"text/plain": [
"alt.LayerChart(...)"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c = alt.Chart(df_train).mark_circle().encode(\n",
" x=\"distance\",\n",
" y=\"total\"\n",
")\n",
"\n",
"c9 = alt.Chart(df_train).mark_line(color=\"red\").encode(\n",
" x=\"distance\",\n",
" y=\"Pred\"\n",
")\n",
"\n",
"c+c9"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"cell_id": "4970705ce2554ae78dd8758fa52ada9e",
"deepnote_cell_height": 81,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 8,
"execution_start": 1651696300546,
"source_hash": "c2602aa8",
"tags": []
},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"cell_id": "20d17375fb7546518bdb1ab22268030e",
"deepnote_cell_height": 81,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 4,
"execution_start": 1651697263223,
"source_hash": "b4f635bd",
"tags": []
},
"outputs": [],
"source": [
"df_plot = pd.DataFrame({\"distance\": np.arange(0,40,0.1)})"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"cell_id": "e810541fbdc24b4b81f25415543c9122",
"deepnote_cell_height": 600,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 695,
"execution_start": 1651697217800,
"source_hash": "46ed7f7a",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
" distance\n",
"0 0.0\n",
"1 0.1\n",
"2 0.2\n",
"3 0.3\n",
"4 0.4\n",
".. ...\n",
"395 39.5\n",
"396 39.6\n",
"397 39.7\n",
"398 39.8\n",
"399 39.9\n",
"\n",
"[400 rows x 1 columns]"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_plot"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"cell_id": "7b809031021648249deb4b653e87d202",
"deepnote_cell_height": 171,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 19,
"execution_start": 1651697219353,
"source_hash": "50d4c81e",
"tags": []
},
"outputs": [],
"source": [
"cols = []\n",
"for deg in range(1,10):\n",
" c = f\"d{deg}\"\n",
" cols.append(c)\n",
" df_train[c] = df_train[\"distance\"]**deg\n",
" df_plot[c] = df_plot[\"distance\"]**deg"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"cell_id": "2f4fff32593646ffbe5bf69359d74e7c",
"deepnote_cell_height": 600,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 76,
"execution_start": 1651697220201,
"source_hash": "46ed7f7a",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
" distance d1 d2 d3 d4 d5 \\\n",
"0 0.0 0.0 0.00 0.000 0.000000e+00 0.000000e+00 \n",
"1 0.1 0.1 0.01 0.001 1.000000e-04 1.000000e-05 \n",
"2 0.2 0.2 0.04 0.008 1.600000e-03 3.200000e-04 \n",
"3 0.3 0.3 0.09 0.027 8.100000e-03 2.430000e-03 \n",
"4 0.4 0.4 0.16 0.064 2.560000e-02 1.024000e-02 \n",
".. ... ... ... ... ... ... \n",
"395 39.5 39.5 1560.25 61629.875 2.434380e+06 9.615801e+07 \n",
"396 39.6 39.6 1568.16 62099.136 2.459126e+06 9.738138e+07 \n",
"397 39.7 39.7 1576.09 62570.773 2.484060e+06 9.861717e+07 \n",
"398 39.8 39.8 1584.04 63044.792 2.509183e+06 9.986547e+07 \n",
"399 39.9 39.9 1592.01 63521.199 2.534496e+06 1.011264e+08 \n",
"\n",
" d6 d7 d8 d9 \n",
"0 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 \n",
"1 1.000000e-06 1.000000e-07 1.000000e-08 1.000000e-09 \n",
"2 6.400000e-05 1.280000e-05 2.560000e-06 5.120000e-07 \n",
"3 7.290000e-04 2.187000e-04 6.561000e-05 1.968300e-05 \n",
"4 4.096000e-03 1.638400e-03 6.553600e-04 2.621440e-04 \n",
".. ... ... ... ... \n",
"395 3.798241e+09 1.500305e+11 5.926206e+12 2.340851e+14 \n",
"396 3.856303e+09 1.527096e+11 6.047300e+12 2.394731e+14 \n",
"397 3.915102e+09 1.554295e+11 6.170553e+12 2.449709e+14 \n",
"398 3.974646e+09 1.581909e+11 6.295998e+12 2.505807e+14 \n",
"399 4.034943e+09 1.609942e+11 6.423669e+12 2.563044e+14 \n",
"\n",
"[400 rows x 10 columns]"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_plot"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"cell_id": "65f405ad1e2040a08af6e4289f98f6ca",
"deepnote_cell_height": 171,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 5,
"execution_start": 1651697222674,
"source_hash": "1f7244f",
"tags": []
},
"outputs": [],
"source": [
"# For each degree, add a column d{deg} = distance**deg to both DataFrames\n",
"cols = []\n",
"for deg in range(1,10):\n",
"    c = f\"d{deg}\"\n",
"    cols.append(c)\n",
"    for x in [df_train, df_plot]:\n",
"        x[c] = x[\"distance\"]**deg"
]
},
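{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a cross-check, the same power columns can be built in a single call with NumPy's `np.vander`. The code below is a minimal sketch and not part of the original workflow: the four `distance` values are made up for illustration, and `df_demo` and `V` are hypothetical names. With `increasing=True`, `np.vander` returns the columns `x**0, x**1, ..., x**9`, so we drop the first (constant) column to match `d1` through `d9`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Hypothetical distances, for illustration only (not the taxis data)\n",
"df_demo = pd.DataFrame({\"distance\": [0.5, 1.2, 3.0, 7.5]})\n",
"\n",
"# Same loop as above: column d{deg} holds distance**deg\n",
"cols = []\n",
"for deg in range(1, 10):\n",
"    c = f\"d{deg}\"\n",
"    cols.append(c)\n",
"    df_demo[c] = df_demo[\"distance\"]**deg\n",
"\n",
"# Columns x**0, ..., x**9; slicing off the constant column leaves x**1, ..., x**9\n",
"V = np.vander(df_demo[\"distance\"].to_numpy(), N=10, increasing=True)[:, 1:]\n",
"\n",
"print(np.allclose(df_demo[cols].to_numpy(), V))  # True\n",
"```"
]
},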
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"cell_id": "f8ebed454a1f4aa0bc7e4ec6e1db0cee",
"deepnote_cell_height": 600,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 39,
"execution_start": 1651697224096,
"source_hash": "46ed7f7a",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
" distance d1 d2 d3 d4 d5 \\\n",
"0 0.0 0.0 0.00 0.000 0.000000e+00 0.000000e+00 \n",
"1 0.1 0.1 0.01 0.001 1.000000e-04 1.000000e-05 \n",
"2 0.2 0.2 0.04 0.008 1.600000e-03 3.200000e-04 \n",
"3 0.3 0.3 0.09 0.027 8.100000e-03 2.430000e-03 \n",
"4 0.4 0.4 0.16 0.064 2.560000e-02 1.024000e-02 \n",
".. ... ... ... ... ... ... \n",
"395 39.5 39.5 1560.25 61629.875 2.434380e+06 9.615801e+07 \n",
"396 39.6 39.6 1568.16 62099.136 2.459126e+06 9.738138e+07 \n",
"397 39.7 39.7 1576.09 62570.773 2.484060e+06 9.861717e+07 \n",
"398 39.8 39.8 1584.04 63044.792 2.509183e+06 9.986547e+07 \n",
"399 39.9 39.9 1592.01 63521.199 2.534496e+06 1.011264e+08 \n",
"\n",
" d6 d7 d8 d9 \n",
"0 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 \n",
"1 1.000000e-06 1.000000e-07 1.000000e-08 1.000000e-09 \n",
"2 6.400000e-05 1.280000e-05 2.560000e-06 5.120000e-07 \n",
"3 7.290000e-04 2.187000e-04 6.561000e-05 1.968300e-05 \n",
"4 4.096000e-03 1.638400e-03 6.553600e-04 2.621440e-04 \n",
".. ... ... ... ... \n",
"395 3.798241e+09 1.500305e+11 5.926206e+12 2.340851e+14 \n",
"396 3.856303e+09 1.527096e+11 6.047300e+12 2.394731e+14 \n",
"397 3.915102e+09 1.554295e+11 6.170553e+12 2.449709e+14 \n",
"398 3.974646e+09 1.581909e+11 6.295998e+12 2.505807e+14 \n",
"399 4.034943e+09 1.609942e+11 6.423669e+12 2.563044e+14 \n",
"\n",
"[400 rows x 10 columns]"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_plot"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"cell_id": "a917f06f07bb400494dce68934a45302",
"deepnote_cell_height": 118.1875,
"deepnote_cell_type": "code",
"deepnote_output_heights": [
21.1875
],
"deepnote_to_be_reexecuted": false,
"execution_millis": 9,
"execution_start": 1651697225637,
"source_hash": "17d627e1",
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"Index(['distance', 'd1', 'd2', 'd3', 'd4', 'd5', 'd6', 'd7', 'd8', 'd9'], dtype='object')"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_plot.columns"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"cell_id": "2a8f7a1f2c7c4cb8bf59fe74d6bc654a",
"deepnote_cell_height": 81,
"deepnote_cell_type": "code",
"deepnote_to_be_reexecuted": false,
"execution_millis": 10,
"execution_start": 1651697226762,
"source_hash": "fc6b2a0f",
"tags": []
},
"outputs": [],
"source": [
"df_plot[\"Pred\"] = reg.predict(df_plot[cols])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following chart looks bad: the fitted polynomial takes values so much larger than the data that the y-axis scale hides all the detail in the curve.\n",
"\n",
"The predictions also don't make sense. For a taxi ride of about 40 miles, our model predicts a cost of approximately negative two billion dollars."
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"cell_id": "1c5c73c0341946719b4ca3635f1a71b8",
"deepnote_cell_height": 638,
"deepnote_cell_type": "code",
"deepnote_output_heights": [
361
],
"deepnote_to_be_reexecuted": false,
"execution_millis": 117,
"execution_start": 1651697226778,
"source_hash": "cc77815",
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
""
],
"text/plain": [
"alt.LayerChart(...)"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c = alt.Chart(df_train).mark_circle().encode(\n",
" x=\"distance\",\n",
" y=\"total\"\n",
")\n",
"\n",
"c9 = alt.Chart(df_plot).mark_line(color=\"red\").encode(\n",
" x=\"distance\",\n",
" y=\"Pred\"\n",
")\n",
"\n",
"c+c9"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will fix this in two ways: by specifying the domain for the y-axis, and by passing `clip=True` to `mark_line()` (this part is important). With `clip=True`, Altair clips the line to the plotting area, so any part of it outside the axis ranges is not drawn."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following is a clear example of overfitting. Our polynomial has too much flexibility (the 9 coefficients, one for each column `d1` through `d9`) relative to the number of data points (40 in this case). This flexibility lets the polynomial pass almost exactly through many of the data points."
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"cell_id": "18a9c39462ab450a81508e1368a6706c",
"deepnote_cell_height": 638,
"deepnote_cell_type": "code",
"deepnote_output_heights": [
361
],
"deepnote_to_be_reexecuted": false,
"execution_millis": 92,
"execution_start": 1651697230180,
"source_hash": "6e2db709",
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
""
],
"text/plain": [
"alt.LayerChart(...)"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c = alt.Chart(df_train).mark_circle().encode(\n",
" x=\"distance\",\n",
" y=\"total\"\n",
")\n",
"\n",
"c9 = alt.Chart(df_plot).mark_line(color=\"red\", clip=True).encode(\n",
" x=\"distance\",\n",
" y=alt.Y(\"Pred\", scale=alt.Scale(domain=(0,200)))\n",
")\n",
"\n",
"c+c9"
]
}
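,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see the same effect in numbers rather than in a picture, here is a small sketch on synthetic data. The linear relationship, the noise level, the random seed, and the sample sizes are all assumptions chosen for illustration; they are not estimated from the taxis data. The sketch fits degree 1 and degree 9 polynomials with NumPy's `Polynomial.fit` (a swapped-in tool, not the regression object `reg` used above) and compares the mean squared error on the 40 training points with the error on 1000 new points drawn from the same process.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from numpy.polynomial import Polynomial\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"def make_data(n):\n",
"    # Synthetic rides (an assumption): cost is linear in distance plus noise\n",
"    x = rng.uniform(0, 20, size=n)\n",
"    y = 3 + 2.5 * x + rng.normal(0, 4, size=n)\n",
"    return x, y\n",
"\n",
"x_train, y_train = make_data(40)\n",
"x_test, y_test = make_data(1000)\n",
"\n",
"def mse(p, x, y):\n",
"    # Mean squared error of polynomial p on the points (x, y)\n",
"    return np.mean((p(x) - y) ** 2)\n",
"\n",
"results = {}\n",
"for deg in [1, 9]:\n",
"    # Polynomial.fit maps x onto [-1, 1] before fitting, which keeps\n",
"    # the degree 9 least-squares problem well conditioned\n",
"    p = Polynomial.fit(x_train, y_train, deg)\n",
"    results[deg] = (mse(p, x_train, y_train), mse(p, x_test, y_test))\n",
"    print(f\"degree {deg}: train MSE {results[deg][0]:.2f}, test MSE {results[deg][1]:.2f}\")\n",
"```\n",
"\n",
"Because every line is also a degree 9 polynomial, the degree 9 fit can never have a larger training error than the degree 1 fit. The test error is what reveals overfitting: with a setup like this, the degree 9 fit typically does worse on the new points."
]
}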
],
"metadata": {
"deepnote": {},
"deepnote_execution_queue": [],
"deepnote_notebook_id": "a77823bb-2b5b-436d-9a2e-689e1064d274",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}