{ "cells": [ { "cell_type": "markdown", "metadata": { "cell_id": "e4e4c69a-c97b-49fb-8c83-2a5b1d3471da", "deepnote_cell_height": 224, "deepnote_cell_type": "markdown", "tags": [] }, "source": [ "# Polynomial Regression 2\n", "\n", "I think the most important concept in machine learning is **overfitting**. The idea is that if your model is too flexible (relative to the number of data points), it can match the training data very closely but do a poor job of predicting new values.\n", "\n", "In polynomial regression, the higher the degree of the polynomial, the more flexible the model. So the higher the degree, the greater the risk of overfitting." ] }, { "cell_type": "markdown", "metadata": { "cell_id": "2f6571ef76ce4b94b0d895a99a3fa62c", "deepnote_cell_height": 70, "deepnote_cell_type": "markdown", "tags": [] }, "source": [ "## The taxis dataset\n", "\n", "The taxis dataset contains information about more than 6000 taxi rides in New York City. Our goal is to model the total cost of a taxi ride using its distance." 
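, "\n", "\n", "As a minimal sketch of the overfitting idea above (assuming synthetic, roughly linear data in place of the taxis dataset, so nothing needs to be downloaded), the code fits degree 1 and degree 9 polynomials to only 40 points and compares the mean squared error on those 40 training points with the error on the held-out points. Polynomial regression is written here as linear regression on the columns `x`, `x**2`, ..., `x**degree` via scikit-learn's `PolynomialFeatures`; the `StandardScaler` step is our own addition to keep the high powers numerically well behaved. The exact numbers depend on the random seed, but the degree 9 fit always has the lower training error, since its model contains the degree 1 model, while its held-out error is typically the worse of the two.\n", "\n", "
```python
# Hypothetical example: synthetic, roughly linear data stands in for the taxis data.
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures, StandardScaler

rng = np.random.default_rng(0)
x = rng.uniform(0, 5, size=1000)               # stand-in for distance
y = 3 + 2.5 * x + rng.normal(0, 3, size=1000)  # stand-in for total: linear plus noise

# Train on only 40 points; the other 960 play the role of future rides.
x_train, y_train = x[:40].reshape(-1, 1), y[:40]
x_test, y_test = x[40:].reshape(-1, 1), y[40:]

def train_and_test_mse(degree):
    # Polynomial regression = linear regression on the columns x, x**2, ..., x**degree.
    # StandardScaler only rescales those columns; with an intercept it does not change the fitted curve.
    model = make_pipeline(
        PolynomialFeatures(degree=degree, include_bias=False),
        StandardScaler(),
        LinearRegression(),
    )
    model.fit(x_train, y_train)
    train_mse = mean_squared_error(y_train, model.predict(x_train))
    test_mse = mean_squared_error(y_test, model.predict(x_test))
    return train_mse, test_mse

results = {degree: train_and_test_mse(degree) for degree in (1, 9)}
for degree, (train_mse, test_mse) in results.items():
    print(f'degree {degree}: train MSE {train_mse:.2f}, test MSE {test_mse:.2f}')
```
"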
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cell_id": "90b5eedfa9f141a1bd52062da5e353cb", "deepnote_cell_height": 135, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 5065, "execution_start": 1651694170552, "source_hash": "64a83654", "tags": [] }, "outputs": [], "source": [ "import pandas as pd\n", "import altair as alt\n", "alt.data_transformers.enable('default', max_rows=10000)\n", "import seaborn as sns" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cell_id": "4a2afe954c8f42d6a3ceb8c0462b216b", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 431, "execution_start": 1651694181207, "source_hash": "eaa10c21", "tags": [] }, "outputs": [], "source": [ "df = sns.load_dataset(\"taxis\").dropna()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "cell_id": "220efa9a0ebe4327b3a1b13209447276", "deepnote_cell_height": 395, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 134, "execution_start": 1651694539556, "source_hash": "c085b6ba", "tags": [] }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " pickup dropoff passengers distance fare tip \\\n", "0 2019-03-23 20:21:09 2019-03-23 20:27:24 1 1.60 7.0 2.15 \n", "1 2019-03-04 16:11:55 2019-03-04 16:19:00 1 0.79 5.0 0.00 \n", "2 2019-03-27 17:53:01 2019-03-27 18:00:25 1 1.37 7.5 2.36 \n", "3 2019-03-10 01:23:59 2019-03-10 01:49:51 1 7.70 27.0 6.15 \n", "4 2019-03-30 13:27:42 2019-03-30 13:37:14 3 2.16 9.0 1.10 \n", "\n", " tolls total color payment pickup_zone \\\n", "0 0.0 12.95 yellow credit card Lenox Hill West \n", "1 0.0 9.30 yellow cash Upper West Side South \n", "2 0.0 14.16 yellow credit card Alphabet City \n", "3 0.0 36.95 yellow credit card Hudson Sq \n", "4 0.0 13.40 yellow credit card Midtown East \n", "\n", " dropoff_zone pickup_borough dropoff_borough \n", "0 UN/Turtle Bay South Manhattan Manhattan \n", "1 Upper West Side South Manhattan Manhattan \n", "2 West Village Manhattan Manhattan \n", "3 Yorkville West Manhattan Manhattan \n", "4 Yorkville West Manhattan Manhattan " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following plot shows that the data follows a linear model very closely. We are going to instead try using polynomial regression with a high degree (degree 9 in this case)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "cell_id": "94eff324ca1448cb9b9a5cb5cbf6396f", "deepnote_cell_height": 512, "deepnote_cell_type": "code", "deepnote_output_heights": [ 361 ], "deepnote_to_be_reexecuted": false, "execution_millis": 1017, "execution_start": 1651694627624, "source_hash": "ede7be0d", "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.Chart(...)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "alt.Chart(df).mark_circle().encode(\n", " x=\"distance\",\n", " y=\"total\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we fit a degree 9 polynomial to all 6000+ data points, the fit will look very good. Instead, we will use only 40 of the data points. In general, the fewer data points you use, the greater the risk of overfitting." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "cell_id": "ce9fa8e7091f45a98aec0c15f1d58d5c", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 6, "execution_start": 1651694778361, "source_hash": "746a4dbc", "tags": [] }, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following is *not* the usual way to call `train_test_split`; further below we will use the usual way, which involves list unpacking.\n", "\n", "Here `train_size=40` means we want 40 randomly chosen rows from the DataFrame `df`. If we had instead used `train_size=0.4`, we would get 40% of the rows." 
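, "\n", "\n", "A minimal sketch of the difference between an integer and a float `train_size`, assuming a hypothetical 200-row toy DataFrame `toy` in place of `df`: an integer is read as a number of rows, a float between 0 and 1 as a fraction of the rows, and in both cases the remaining rows go to the second DataFrame. The sketch uses list unpacking, the usual calling style.\n", "\n", "
```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical toy DataFrame with 200 rows (the taxis data has 6341 rows after dropna).
toy = pd.DataFrame({'distance': np.arange(200) * 0.1})

# An int train_size is an absolute number of rows ...
train_a, test_a = train_test_split(toy, train_size=40)
# ... while a float between 0 and 1 is a fraction of the rows.
train_b, test_b = train_test_split(toy, train_size=0.4)

print(len(train_a), len(test_a))  # 40 160
print(len(train_b), len(test_b))  # 80 120
```
"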
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "cell_id": "2c9de1ad14144e8aa06eccb2063508a2", "deepnote_cell_height": 99, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 6, "execution_start": 1651694850545, "source_hash": "eb474865", "tags": [] }, "outputs": [], "source": [ "# use 40 data points\n", "a = train_test_split(df, train_size=40)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "cell_id": "9e8eb4fae0e94e5aa4bbba7cbb6489e1", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 4, "execution_start": 1651694865156, "source_hash": "1478ef", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "list" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(a)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "cell_id": "a3d6df3a7730438fa11669f92061a507", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 11, "execution_start": 1651694877667, "source_hash": "f64cd242", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "2" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(a)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "cell_id": "fa81cbe73a274b7589e49352de29782a", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 6, "execution_start": 1651694907861, "source_hash": "704ff305", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "pandas.core.frame.DataFrame" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(a[0])" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "cell_id": 
"d8713820e1794d1d8af5e6ac524146e9", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 9, "execution_start": 1651694918308, "source_hash": "2abccaa2", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(40, 14)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a[0].shape" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "cell_id": "8ffd16afd5f04e30a756acebe7fbb678", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 17, "execution_start": 1651694930593, "source_hash": "bdee13fc", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(6301, 14)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a[1].shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Currently `a[0]` contains 40 rows from the DataFrame, and `a[1]` contains the remaining 6301 rows." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "cell_id": "9141f9534a08453fb2275e0f7debb008", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 11, "execution_start": 1651694935057, "source_hash": "14f60b8f", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(6341, 14)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The rows are chosen at random. 
(They aren't even presented in order.)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "cell_id": "67e4e225eb8246a2b27d765d5c4fabe4", "deepnote_cell_height": 395, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 62, "execution_start": 1651694995970, "source_hash": "b62c9399", "tags": [] }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " pickup dropoff passengers distance fare \\\n", "1855 2019-03-29 14:20:18 2019-03-29 14:36:35 6 2.72 12.5 \n", "5029 2019-03-21 06:51:40 2019-03-21 07:07:11 1 3.39 14.0 \n", "5277 2019-03-19 04:54:04 2019-03-19 05:00:05 2 1.65 7.5 \n", "247 2019-03-07 18:49:41 2019-03-07 19:07:05 1 2.75 13.0 \n", "912 2019-03-15 23:13:44 2019-03-15 23:24:47 1 1.70 9.5 \n", "\n", " tip tolls total color payment pickup_zone \\\n", "1855 3.16 0.0 18.96 yellow credit card Hudson Sq \n", "5029 2.75 0.0 20.05 yellow credit card Central Park \n", "5277 2.00 0.0 13.30 yellow credit card East Chelsea \n", "247 3.46 0.0 20.76 yellow credit card Yorkville West \n", "912 2.65 0.0 15.95 yellow credit card Lincoln Square West \n", "\n", " dropoff_zone pickup_borough dropoff_borough \n", "1855 Murray Hill Manhattan Manhattan \n", "5029 Times Sq/Theatre District Manhattan Manhattan \n", "5277 Clinton East Manhattan Manhattan \n", "247 Lincoln Square East Manhattan Manhattan \n", "912 East Chelsea Manhattan Manhattan " ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a[0].head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The usual way to call `train_test_split` is to use list unpacking, which is what we do in the following code." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "cell_id": "2c3d826550ed4a828eb7476c1458e8aa", "deepnote_cell_height": 99, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 6, "execution_start": 1651695117640, "source_hash": "cf8362a5", "tags": [] }, "outputs": [], "source": [ "# list unpacking\n", "df_train, df_test = train_test_split(df, train_size=40)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now `df_train` will contain a different set of 40 rows from `df`." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "cell_id": "a1432e5d1bb94274a5d02f0c1e47d660", "deepnote_cell_height": 395, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 60, "execution_start": 1651695119730, "source_hash": "5e2a9c26", "tags": [] }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " pickup dropoff passengers distance fare \\\n", "1091 2019-03-05 20:17:31 2019-03-05 20:35:21 1 4.20 16.0 \n", "5459 2019-03-29 16:25:22 2019-03-29 17:15:57 1 12.76 39.5 \n", "4164 2019-03-20 09:11:06 2019-03-20 09:21:44 2 0.83 8.0 \n", "819 2019-03-17 19:30:19 2019-03-17 19:41:39 1 2.10 9.5 \n", "6136 2019-03-05 19:01:21 2019-03-05 19:05:56 1 0.87 5.0 \n", "\n", " tip tolls total color payment pickup_zone \\\n", "1091 0.0 0.0 19.8 yellow cash Penn Station/Madison Sq West \n", "5459 0.0 0.0 41.3 green credit card Brighton Beach \n", "4164 0.0 0.0 11.3 yellow cash Midtown South \n", "819 0.0 0.0 12.8 yellow cash Lincoln Square East \n", "6136 0.0 0.0 6.8 green cash Elmhurst \n", "\n", " dropoff_zone pickup_borough dropoff_borough \n", "1091 Battery Park City Manhattan Manhattan \n", "5459 Richmond Hill Brooklyn Queens \n", "4164 Midtown East Manhattan Manhattan \n", "819 Upper West Side North Manhattan Manhattan \n", "6136 Elmhurst Queens Queens " ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train.head()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "cell_id": "c0406a4f79af4bd3a80d328ce4c5f9ff", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 170, "execution_start": 1651695196957, "source_hash": "db981d08", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(40, 14)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train.shape" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "fdbe95819f9646e8bb067705e18c28e5", "deepnote_cell_height": 70, "deepnote_cell_type": "markdown", "tags": [] }, "source": [ "## Fitting polynomial regression" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "cell_id": "65ef0388c69a4933b65e75ba6936d998", "deepnote_cell_height": 135, "deepnote_cell_type": "code", 
"deepnote_to_be_reexecuted": false, "execution_millis": 16, "execution_start": 1651695572487, "source_hash": "a285ad9", "tags": [] }, "outputs": [], "source": [ "cols = []\n", "for deg in range(1,10):\n", " cols.append(f\"d{deg}\")\n", " df_train[f\"d{deg}\"] = df_train[\"distance\"]**deg" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is a more DRY approach, where we only have to type `f\"d{deg}\"` once." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "cell_id": "ec9358bd85a2452f9736696123a43ba8", "deepnote_cell_height": 153, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 6, "execution_start": 1651695621654, "source_hash": "6dc14bfc", "tags": [] }, "outputs": [], "source": [ "cols = []\n", "for deg in range(1,10):\n", " c = f\"d{deg}\"\n", " cols.append(c)\n", " df_train[c] = df_train[\"distance\"]**deg" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "cell_id": "37924b8fa2e141abb05c0f0fee4bb327", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 7, "execution_start": 1651695624088, "source_hash": "cb73c151", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "['d1', 'd2', 'd3', 'd4', 'd5', 'd6', 'd7', 'd8', 'd9']" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cols" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice how 9 rows have been added to `df_train`. They contain the 1st, 2nd, ..., 9th powers of the values in the \"distance\" column." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "cell_id": "03a6321f99094919b5be5bda6e321a34", "deepnote_cell_height": 600, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 179, "execution_start": 1651695444130, "source_hash": "e7a4a97c", "tags": [] }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " pickup dropoff passengers distance fare \\\n", "1091 2019-03-05 20:17:31 2019-03-05 20:35:21 1 4.20 16.0 \n", "5459 2019-03-29 16:25:22 2019-03-29 17:15:57 1 12.76 39.5 \n", "4164 2019-03-20 09:11:06 2019-03-20 09:21:44 2 0.83 8.0 \n", "819 2019-03-17 19:30:19 2019-03-17 19:41:39 1 2.10 9.5 \n", "6136 2019-03-05 19:01:21 2019-03-05 19:05:56 1 0.87 5.0 \n", "\n", " tip tolls total color payment ... dropoff_borough d1 \\\n", "1091 0.0 0.0 19.8 yellow cash ... Manhattan 4.20 \n", "5459 0.0 0.0 41.3 green credit card ... Queens 12.76 \n", "4164 0.0 0.0 11.3 yellow cash ... Manhattan 0.83 \n", "819 0.0 0.0 12.8 yellow cash ... Manhattan 2.10 \n", "6136 0.0 0.0 6.8 green cash ... Queens 0.87 \n", "\n", " d2 d3 d4 d5 d6 \\\n", "1091 17.6400 74.088000 311.169600 1306.912320 5.489032e+03 \n", "5459 162.8176 2077.552576 26509.570870 338262.124298 4.316225e+06 \n", "4164 0.6889 0.571787 0.474583 0.393904 3.269404e-01 \n", "819 4.4100 9.261000 19.448100 40.841010 8.576612e+01 \n", "6136 0.7569 0.658503 0.572898 0.498421 4.336262e-01 \n", "\n", " d7 d8 d9 \n", "1091 2.305393e+04 9.682652e+04 4.066714e+05 \n", "5459 5.507503e+07 7.027573e+08 8.967184e+09 \n", "4164 2.713605e-01 2.252292e-01 1.869403e-01 \n", "819 1.801089e+02 3.782286e+02 7.942800e+02 \n", "6136 3.772548e-01 3.282117e-01 2.855442e-01 \n", "\n", "[5 rows x 23 columns]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Remember, as we saw last time, polynomial regression can be viewed as a special case of linear regression. 
(The reverse is also true, and more obvious: linear regression is just polynomial regression of degree 1.)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "cell_id": "18595bda00f4433d9c2b54c31c6df588", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 0, "execution_start": 1651695714609, "source_hash": "9527aab5", "tags": [] }, "outputs": [], "source": [ "from sklearn.linear_model import LinearRegression" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "cell_id": "0425f7bee2e6480ca26be1b0d37ad389", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 1, "execution_start": 1651695731464, "source_hash": "cd72858e", "tags": [] }, "outputs": [], "source": [ "reg = LinearRegression()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is important to fit only on `df_train`: we want to use only 40 data points, because that makes the overfitting easier to see." ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "cell_id": "67ac2d89e0304659b4c6ea783d14502e", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 13, "execution_start": 1651695879050, "source_hash": "40287825", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "LinearRegression()" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reg.fit(df_train[cols], df_train[\"total\"])" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "cell_id": "c188ddcdc9074691a8521537fc2fdad8", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 0, "execution_start": 1651695926138, "source_hash": "97a9514a", "tags": [] }, "outputs": [], "source": [ "df_train[\"Pred\"] = reg.predict(df_train[cols])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The red line in the following chart shows the fitted polynomial. It doesn't look like a polynomial, because Altair (like MATLAB and Matplotlib) connects consecutive points with straight line segments, and we only have 40 points. Below we will make a DataFrame with many more input values, so that the polynomial will look smoother." ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "cell_id": "ca8ca92edd8b48f5951213896094eff8", "deepnote_cell_height": 633, "deepnote_cell_type": "code", "deepnote_output_heights": [ 356 ], "deepnote_to_be_reexecuted": false, "execution_millis": 103, "execution_start": 1651696050196, "source_hash": "d9da7148", "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.LayerChart(...)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c = alt.Chart(df_train).mark_circle().encode(\n", " x=\"distance\",\n", " y=\"total\"\n", ")\n", "\n", "c9 = alt.Chart(df_train).mark_line(color=\"red\").encode(\n", " x=\"distance\",\n", " y=\"Pred\"\n", ")\n", "\n", "c+c9" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "cell_id": "4970705ce2554ae78dd8758fa52ada9e", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 8, "execution_start": 1651696300546, "source_hash": "c2602aa8", "tags": [] }, "outputs": [], "source": [ "import numpy as np" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "cell_id": "20d17375fb7546518bdb1ab22268030e", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 4, "execution_start": 1651697263223, "source_hash": "b4f635bd", "tags": [] }, "outputs": [], "source": [ "df_plot = pd.DataFrame({\"distance\": np.arange(0,40,0.1)})" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "cell_id": "e810541fbdc24b4b81f25415543c9122", "deepnote_cell_height": 600, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 695, "execution_start": 1651697217800, "source_hash": "46ed7f7a", "tags": [] }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " distance\n", "0 0.0\n", "1 0.1\n", "2 0.2\n", "3 0.3\n", "4 0.4\n", ".. ...\n", "395 39.5\n", "396 39.6\n", "397 39.7\n", "398 39.8\n", "399 39.9\n", "\n", "[400 rows x 1 columns]" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_plot" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "cell_id": "7b809031021648249deb4b653e87d202", "deepnote_cell_height": 171, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 19, "execution_start": 1651697219353, "source_hash": "50d4c81e", "tags": [] }, "outputs": [], "source": [ "cols = []\n", "for deg in range(1,10):\n", " c = f\"d{deg}\"\n", " cols.append(c)\n", " df_train[c] = df_train[\"distance\"]**deg\n", " df_plot[c] = df_plot[\"distance\"]**deg" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "cell_id": "2f4fff32593646ffbe5bf69359d74e7c", "deepnote_cell_height": 600, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 76, "execution_start": 1651697220201, "source_hash": "46ed7f7a", "tags": [] }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " distance d1 d2 d3 d4 d5 \\\n", "0 0.0 0.0 0.00 0.000 0.000000e+00 0.000000e+00 \n", "1 0.1 0.1 0.01 0.001 1.000000e-04 1.000000e-05 \n", "2 0.2 0.2 0.04 0.008 1.600000e-03 3.200000e-04 \n", "3 0.3 0.3 0.09 0.027 8.100000e-03 2.430000e-03 \n", "4 0.4 0.4 0.16 0.064 2.560000e-02 1.024000e-02 \n", ".. ... ... ... ... ... ... \n", "395 39.5 39.5 1560.25 61629.875 2.434380e+06 9.615801e+07 \n", "396 39.6 39.6 1568.16 62099.136 2.459126e+06 9.738138e+07 \n", "397 39.7 39.7 1576.09 62570.773 2.484060e+06 9.861717e+07 \n", "398 39.8 39.8 1584.04 63044.792 2.509183e+06 9.986547e+07 \n", "399 39.9 39.9 1592.01 63521.199 2.534496e+06 1.011264e+08 \n", "\n", " d6 d7 d8 d9 \n", "0 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 \n", "1 1.000000e-06 1.000000e-07 1.000000e-08 1.000000e-09 \n", "2 6.400000e-05 1.280000e-05 2.560000e-06 5.120000e-07 \n", "3 7.290000e-04 2.187000e-04 6.561000e-05 1.968300e-05 \n", "4 4.096000e-03 1.638400e-03 6.553600e-04 2.621440e-04 \n", ".. ... ... ... ... 
\n", "395 3.798241e+09 1.500305e+11 5.926206e+12 2.340851e+14 \n", "396 3.856303e+09 1.527096e+11 6.047300e+12 2.394731e+14 \n", "397 3.915102e+09 1.554295e+11 6.170553e+12 2.449709e+14 \n", "398 3.974646e+09 1.581909e+11 6.295998e+12 2.505807e+14 \n", "399 4.034943e+09 1.609942e+11 6.423669e+12 2.563044e+14 \n", "\n", "[400 rows x 10 columns]" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_plot" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "cell_id": "863cf499f800469680686e991e72f030", "deepnote_cell_height": 171, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 10, "execution_start": 1651697221882, "source_hash": "50d4c81e", "tags": [] }, "outputs": [], "source": [ "cols = []\n", "for deg in range(1,10):\n", " c = f\"d{deg}\"\n", " cols.append(c)\n", " df_train[c] = df_train[\"distance\"]**deg\n", " df_plot[c] = df_plot[\"distance\"]**deg" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "cell_id": "65f405ad1e2040a08af6e4289f98f6ca", "deepnote_cell_height": 171, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 5, "execution_start": 1651697222674, "source_hash": "1f7244f", "tags": [] }, "outputs": [], "source": [ "cols = []\n", "for deg in range(1,10):\n", " c = f\"d{deg}\"\n", " cols.append(c)\n", " for x in [df_train, df_plot]:\n", " x[c] = x[\"distance\"]**deg" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "cell_id": "f8ebed454a1f4aa0bc7e4ec6e1db0cee", "deepnote_cell_height": 600, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 39, "execution_start": 1651697224096, "source_hash": "46ed7f7a", "tags": [] }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " distance d1 d2 d3 d4 d5 \\\n", "0 0.0 0.0 0.00 0.000 0.000000e+00 0.000000e+00 \n", "1 0.1 0.1 0.01 0.001 1.000000e-04 1.000000e-05 \n", "2 0.2 0.2 0.04 0.008 1.600000e-03 3.200000e-04 \n", "3 0.3 0.3 0.09 0.027 8.100000e-03 2.430000e-03 \n", "4 0.4 0.4 0.16 0.064 2.560000e-02 1.024000e-02 \n", ".. ... ... ... ... ... ... \n", "395 39.5 39.5 1560.25 61629.875 2.434380e+06 9.615801e+07 \n", "396 39.6 39.6 1568.16 62099.136 2.459126e+06 9.738138e+07 \n", "397 39.7 39.7 1576.09 62570.773 2.484060e+06 9.861717e+07 \n", "398 39.8 39.8 1584.04 63044.792 2.509183e+06 9.986547e+07 \n", "399 39.9 39.9 1592.01 63521.199 2.534496e+06 1.011264e+08 \n", "\n", " d6 d7 d8 d9 \n", "0 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 \n", "1 1.000000e-06 1.000000e-07 1.000000e-08 1.000000e-09 \n", "2 6.400000e-05 1.280000e-05 2.560000e-06 5.120000e-07 \n", "3 7.290000e-04 2.187000e-04 6.561000e-05 1.968300e-05 \n", "4 4.096000e-03 1.638400e-03 6.553600e-04 2.621440e-04 \n", ".. ... ... ... ... 
\n", "395 3.798241e+09 1.500305e+11 5.926206e+12 2.340851e+14 \n", "396 3.856303e+09 1.527096e+11 6.047300e+12 2.394731e+14 \n", "397 3.915102e+09 1.554295e+11 6.170553e+12 2.449709e+14 \n", "398 3.974646e+09 1.581909e+11 6.295998e+12 2.505807e+14 \n", "399 4.034943e+09 1.609942e+11 6.423669e+12 2.563044e+14 \n", "\n", "[400 rows x 10 columns]" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_plot" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "cell_id": "a917f06f07bb400494dce68934a45302", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 9, "execution_start": 1651697225637, "source_hash": "17d627e1", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "Index(['distance', 'd1', 'd2', 'd3', 'd4', 'd5', 'd6', 'd7', 'd8', 'd9'], dtype='object')" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_plot.columns" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "cell_id": "2a8f7a1f2c7c4cb8bf59fe74d6bc654a", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 10, "execution_start": 1651697226762, "source_hash": "fc6b2a0f", "tags": [] }, "outputs": [], "source": [ "df_plot[\"Pred\"] = reg.predict(df_plot[cols])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following chart looks bad: the fitted polynomial's values are so much larger than the data that, at this scale, we can't see any detail in the curve.\n", "\n", "The predictions also make no sense. For a taxi ride of 40 miles, our model predicts a cost of approximately negative two billion dollars."
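, "\n", "\n", "A prediction this extreme is typical of a high-degree polynomial evaluated outside the range of its training data. The sketch below illustrates this on synthetic data, not on the taxi data: it assumes a hypothetical linear fare rule (total = 5 + 3 * distance, plus Gaussian noise) and 40 rides between 0.5 and 20 miles, then compares a degree-1 fit with a degree-9 fit at 39.9 miles. It uses NumPy's `Polynomial.fit` (which rescales the inputs internally for numerical stability) in place of the `reg` model fitted earlier in this notebook, so the printed numbers are illustrative only.

```python
import numpy as np
from numpy.polynomial import Polynomial

# Hypothetical data (an assumption, not the taxi dataset): 40 rides between
# 0.5 and 20 miles whose fare follows 5 + 3 * distance plus Gaussian noise.
rng = np.random.default_rng(0)
distance = rng.uniform(0.5, 20, size=40)
total = 5 + 3 * distance + rng.normal(0, 2, size=40)

# Least-squares polynomial fits of degree 1 and degree 9. Polynomial.fit maps
# the observed distance range onto [-1, 1] before fitting, which keeps the
# degree-9 problem better conditioned than using raw powers of distance.
fits = {deg: Polynomial.fit(distance, total, deg) for deg in (1, 9)}


def train_rmse(p):
    # Root-mean-squared error on the 40 training points themselves.
    return np.sqrt(np.mean((p(distance) - total) ** 2))


x_far = 39.9  # about twice the largest training distance
true_far = 5 + 3 * x_far  # fare under the assumed linear rule
for deg, p in fits.items():
    print(f'degree {deg}: train RMSE = {train_rmse(p):.3f}, '
          f'prediction at {x_far} miles = {p(x_far):.4g} '
          f'(assumed true fare {true_far:.1f})')
```

With data like this, the degree-9 fit's training RMSE can never exceed the line's (a straight line is a special case of a degree-9 polynomial), yet its prediction at 39.9 miles typically lands far from the assumed true fare while the linear fit stays close. That is the same pattern as the negative-two-billion-dollar prediction described above.
"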
] }, { "cell_type": "code", "execution_count": 36, "metadata": { "cell_id": "1c5c73c0341946719b4ca3635f1a71b8", "deepnote_cell_height": 638, "deepnote_cell_type": "code", "deepnote_output_heights": [ 361 ], "deepnote_to_be_reexecuted": false, "execution_millis": 117, "execution_start": 1651697226778, "source_hash": "cc77815", "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.LayerChart(...)" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c = alt.Chart(df_train).mark_circle().encode(\n", " x=\"distance\",\n", " y=\"total\"\n", ")\n", "\n", "c9 = alt.Chart(df_plot).mark_line(color=\"red\").encode(\n", " x=\"distance\",\n", " y=\"Pred\"\n", ")\n", "\n", "c+c9" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will fix this in two ways: by setting the y-axis domain to (0, 200), and, importantly, by passing `clip=True` to `mark_line()`. With `clip=True`, any part of the line that falls outside the axis ranges is not drawn." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following is a clear example of overfitting. Our polynomial has too much flexibility (nine coefficients, one for each of the powers `d1` through `d9`) relative to the number of data points (40 in this case). This flexibility lets the polynomial pass almost exactly through many of the data points." ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "cell_id": "18a9c39462ab450a81508e1368a6706c", "deepnote_cell_height": 638, "deepnote_cell_type": "code", "deepnote_output_heights": [ 361 ], "deepnote_to_be_reexecuted": false, "execution_millis": 92, "execution_start": 1651697230180, "source_hash": "6e2db709", "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.LayerChart(...)" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c = alt.Chart(df_train).mark_circle().encode(\n", " x=\"distance\",\n", " y=\"total\"\n", ")\n", "\n", "c9 = alt.Chart(df_plot).mark_line(color=\"red\", clip=True).encode(\n", " x=\"distance\",\n", " y=alt.Y(\"Pred\", scale=alt.Scale(domain=(0,200)))\n", ")\n", "\n", "c+c9" ] } ], "metadata": { "deepnote": {}, "deepnote_execution_queue": [], "deepnote_notebook_id": "a77823bb-2b5b-436d-9a2e-689e1064d274", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" } }, "nbformat": 4, "nbformat_minor": 4 }