{ "cells": [ { "cell_type": "markdown", "metadata": { "cell_id": "e4e4c69a-c97b-49fb-8c83-2a5b1d3471da", "deepnote_cell_height": 134, "deepnote_cell_type": "markdown", "tags": [] }, "source": [ "# Linear and Polynomial Regression with the taxis dataset" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cell_id": "90b5eedfa9f141a1bd52062da5e353cb", "deepnote_cell_height": 135, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 3177, "execution_start": 1651867080972, "source_hash": "64a83654", "tags": [] }, "outputs": [], "source": [ "import pandas as pd\n", "import altair as alt\n", "alt.data_transformers.enable('default', max_rows=10000)\n", "import seaborn as sns" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cell_id": "4a2afe954c8f42d6a3ceb8c0462b216b", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 383, "execution_start": 1651867148401, "source_hash": "eaa10c21", "tags": [] }, "outputs": [], "source": [ "df = sns.load_dataset(\"taxis\").dropna()" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "2f6571ef76ce4b94b0d895a99a3fa62c", "deepnote_cell_height": 144.78125, "deepnote_cell_type": "markdown", "tags": [] }, "source": [ "## Linear regression\n", "\n", "* Fit a linear regression model to the data from the taxis dataset, using multiple input variables (also called features, also called predictors), and with \"total\" as the output variable (the target)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "cell_id": "976eb0e3e3d5497e82d41d358523aac5", "deepnote_cell_height": 395, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 31, "execution_start": 1651867285845, "source_hash": "c085b6ba", "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pickupdropoffpassengersdistancefaretiptollstotalcolorpaymentpickup_zonedropoff_zonepickup_boroughdropoff_borough
02019-03-23 20:21:092019-03-23 20:27:2411.607.02.150.012.95yellowcredit cardLenox Hill WestUN/Turtle Bay SouthManhattanManhattan
12019-03-04 16:11:552019-03-04 16:19:0010.795.00.000.09.30yellowcashUpper West Side SouthUpper West Side SouthManhattanManhattan
22019-03-27 17:53:012019-03-27 18:00:2511.377.52.360.014.16yellowcredit cardAlphabet CityWest VillageManhattanManhattan
32019-03-10 01:23:592019-03-10 01:49:5117.7027.06.150.036.95yellowcredit cardHudson SqYorkville WestManhattanManhattan
42019-03-30 13:27:422019-03-30 13:37:1432.169.01.100.013.40yellowcredit cardMidtown EastYorkville WestManhattanManhattan
\n", "
" ], "text/plain": [ " pickup dropoff passengers distance fare tip \\\n", "0 2019-03-23 20:21:09 2019-03-23 20:27:24 1 1.60 7.0 2.15 \n", "1 2019-03-04 16:11:55 2019-03-04 16:19:00 1 0.79 5.0 0.00 \n", "2 2019-03-27 17:53:01 2019-03-27 18:00:25 1 1.37 7.5 2.36 \n", "3 2019-03-10 01:23:59 2019-03-10 01:49:51 1 7.70 27.0 6.15 \n", "4 2019-03-30 13:27:42 2019-03-30 13:37:14 3 2.16 9.0 1.10 \n", "\n", " tolls total color payment pickup_zone \\\n", "0 0.0 12.95 yellow credit card Lenox Hill West \n", "1 0.0 9.30 yellow cash Upper West Side South \n", "2 0.0 14.16 yellow credit card Alphabet City \n", "3 0.0 36.95 yellow credit card Hudson Sq \n", "4 0.0 13.40 yellow credit card Midtown East \n", "\n", " dropoff_zone pickup_borough dropoff_borough \n", "0 UN/Turtle Bay South Manhattan Manhattan \n", "1 Upper West Side South Manhattan Manhattan \n", "2 West Village Manhattan Manhattan \n", "3 Yorkville West Manhattan Manhattan \n", "4 Yorkville West Manhattan Manhattan " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What are the rows with the biggest values in the \"tolls\" column?" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "cell_id": "8688e54e02b0424d87cb742b65c003e4", "deepnote_cell_height": 600, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 101, "execution_start": 1651867460008, "source_hash": "7b3c1175", "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pickupdropoffpassengersdistancefaretiptollstotalcolorpaymentpickup_zonedropoff_zonepickup_boroughdropoff_borough
53642019-03-17 16:59:172019-03-17 18:04:08236.70150.000.0024.02174.82yellowcashJFK AirportJFK AirportQueensQueens
21222019-03-08 00:40:322019-03-08 01:11:53115.5144.0016.2717.2881.35yellowcredit cardTriBeCa/Civic CenterWest BrightonManhattanStaten Island
36402019-03-22 07:54:092019-03-22 09:05:13116.4252.000.0012.5067.80yellowcashJFK AirportMurray HillQueensManhattan
59112019-03-09 12:27:512019-03-09 13:11:18111.4039.000.0011.5251.32greencredit cardWindsor TerraceClinton EastBrooklynManhattan
57282019-03-01 17:07:092019-03-01 18:05:41121.2765.590.0011.5277.61greencredit cardCambria HeightsMorningside HeightsQueensManhattan
.............................................
22032019-03-03 00:24:502019-03-03 00:56:1714.7222.005.160.0030.96yellowcredit cardMeatpacking/West Village WestWilliamsburg (South Side)ManhattanBrooklyn
22022019-03-14 22:32:332019-03-14 22:49:3912.9013.003.360.0020.16yellowcredit cardEast ChelseaEast VillageManhattanManhattan
22012019-03-18 21:16:422019-03-18 21:27:4913.0011.503.060.0018.36yellowcredit cardClinton EastUpper East Side NorthManhattanManhattan
22002019-03-03 07:21:402019-03-03 07:39:1217.2022.007.550.0032.85yellowcredit cardMidtown CenterWorld Trade CenterManhattanManhattan
64322019-03-13 19:31:222019-03-13 19:48:0213.8515.003.360.0020.16greencredit cardBoerum HillWindsor TerraceBrooklynBrooklyn
\n", "

6341 rows × 14 columns

\n", "
" ], "text/plain": [ " pickup dropoff passengers distance fare \\\n", "5364 2019-03-17 16:59:17 2019-03-17 18:04:08 2 36.70 150.00 \n", "2122 2019-03-08 00:40:32 2019-03-08 01:11:53 1 15.51 44.00 \n", "3640 2019-03-22 07:54:09 2019-03-22 09:05:13 1 16.42 52.00 \n", "5911 2019-03-09 12:27:51 2019-03-09 13:11:18 1 11.40 39.00 \n", "5728 2019-03-01 17:07:09 2019-03-01 18:05:41 1 21.27 65.59 \n", "... ... ... ... ... ... \n", "2203 2019-03-03 00:24:50 2019-03-03 00:56:17 1 4.72 22.00 \n", "2202 2019-03-14 22:32:33 2019-03-14 22:49:39 1 2.90 13.00 \n", "2201 2019-03-18 21:16:42 2019-03-18 21:27:49 1 3.00 11.50 \n", "2200 2019-03-03 07:21:40 2019-03-03 07:39:12 1 7.20 22.00 \n", "6432 2019-03-13 19:31:22 2019-03-13 19:48:02 1 3.85 15.00 \n", "\n", " tip tolls total color payment \\\n", "5364 0.00 24.02 174.82 yellow cash \n", "2122 16.27 17.28 81.35 yellow credit card \n", "3640 0.00 12.50 67.80 yellow cash \n", "5911 0.00 11.52 51.32 green credit card \n", "5728 0.00 11.52 77.61 green credit card \n", "... ... ... ... ... ... \n", "2203 5.16 0.00 30.96 yellow credit card \n", "2202 3.36 0.00 20.16 yellow credit card \n", "2201 3.06 0.00 18.36 yellow credit card \n", "2200 7.55 0.00 32.85 yellow credit card \n", "6432 3.36 0.00 20.16 green credit card \n", "\n", " pickup_zone dropoff_zone pickup_borough \\\n", "5364 JFK Airport JFK Airport Queens \n", "2122 TriBeCa/Civic Center West Brighton Manhattan \n", "3640 JFK Airport Murray Hill Queens \n", "5911 Windsor Terrace Clinton East Brooklyn \n", "5728 Cambria Heights Morningside Heights Queens \n", "... ... ... ... 
\n", "2203 Meatpacking/West Village West Williamsburg (South Side) Manhattan \n", "2202 East Chelsea East Village Manhattan \n", "2201 Clinton East Upper East Side North Manhattan \n", "2200 Midtown Center World Trade Center Manhattan \n", "6432 Boerum Hill Windsor Terrace Brooklyn \n", "\n", " dropoff_borough \n", "5364 Queens \n", "2122 Staten Island \n", "3640 Manhattan \n", "5911 Manhattan \n", "5728 Manhattan \n", "... ... \n", "2203 Brooklyn \n", "2202 Manhattan \n", "2201 Manhattan \n", "2200 Manhattan \n", "6432 Brooklyn \n", "\n", "[6341 rows x 14 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.sort_values(\"tolls\", ascending=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's try to use the following columns as the inputs for our linear regression." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "cell_id": "5d200510d6cd473993055fac958d6d84", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 0, "execution_start": 1651867531843, "source_hash": "4f299aee", "tags": [] }, "outputs": [], "source": [ "cols = [\"distance\", \"tip\", \"tolls\", \"pickup_borough\"]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "cell_id": "6036759038e14127a3afc5ec29642a02", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_start": 1651867578382, "source_hash": "9527aab5", "tags": [] }, "outputs": [], "source": [ "from sklearn.linear_model import LinearRegression" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "cell_id": "dd4831013cb34e028b80b06e30264260", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 2, "execution_start": 1651867602454, "source_hash": "cd72858e", "tags": [] }, "outputs": [], "source": [ "reg = LinearRegression()" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "This doesn't work, because the values in the \"pickup_borough\" column are strings, not numbers." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "cell_id": "ae2578df1ea04c38965807b8338abbde", "deepnote_cell_height": 144.1875, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 645, "execution_start": 1651867712330, "source_hash": "623c24f0", "tags": [ "output_scroll" ] }, "outputs": [ { "ename": "ValueError", "evalue": "could not convert string to float: 'Manhattan'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/var/folders/8j/gshrlmtn7dg4qtztj4d4t_w40000gn/T/ipykernel_15733/3118760667.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mreg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdf\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcols\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdf\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"total\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/miniconda3/envs/math10s22/lib/python3.7/site-packages/sklearn/linear_model/_base.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m 661\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 662\u001b[0m X, y = self._validate_data(\n\u001b[0;32m--> 663\u001b[0;31m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maccept_sparse\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maccept_sparse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_numeric\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mmulti_output\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 664\u001b[0m )\n\u001b[1;32m 665\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/math10s22/lib/python3.7/site-packages/sklearn/base.py\u001b[0m in \u001b[0;36m_validate_data\u001b[0;34m(self, X, y, reset, validate_separately, **check_params)\u001b[0m\n\u001b[1;32m 579\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcheck_array\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mcheck_y_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 580\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 581\u001b[0;31m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcheck_X_y\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mcheck_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 582\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 583\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/math10s22/lib/python3.7/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36mcheck_X_y\u001b[0;34m(X, y, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, estimator)\u001b[0m\n\u001b[1;32m 974\u001b[0m 
\u001b[0mensure_min_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mensure_min_samples\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 975\u001b[0m \u001b[0mensure_min_features\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mensure_min_features\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 976\u001b[0;31m \u001b[0mestimator\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 977\u001b[0m )\n\u001b[1;32m 978\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/math10s22/lib/python3.7/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36mcheck_array\u001b[0;34m(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator)\u001b[0m\n\u001b[1;32m 744\u001b[0m \u001b[0marray\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0marray\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcasting\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"unsafe\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 745\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 746\u001b[0;31m \u001b[0marray\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0morder\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 747\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mComplexWarning\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mcomplex_warning\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 748\u001b[0m raise ValueError(\n", "\u001b[0;32m~/miniconda3/envs/math10s22/lib/python3.7/site-packages/pandas/core/generic.py\u001b[0m in \u001b[0;36m__array__\u001b[0;34m(self, dtype)\u001b[0m\n\u001b[1;32m 1991\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1992\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__array__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mNpDtype\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1993\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_values\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1994\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1995\u001b[0m def __array_wrap__(\n", "\u001b[0;31mValueError\u001b[0m: could not convert string to float: 'Manhattan'" ] } ], "source": [ "reg.fit(df[cols],df[\"total\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's make a new column called \"Manhattan\". This will contain `1` for the \"Manhattan\" pickup borough rows, and contain `0` for all the other rows." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "cell_id": "4f3fe280d62344b1aace68032c9bfaba", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 4, "execution_start": 1651867826643, "source_hash": "abd1c605", "tags": [] }, "outputs": [], "source": [ "df[\"Manhattan\"] = 0" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "cell_id": "4f8410183e764b71ac30ecd43bce7cd7", "deepnote_cell_height": 329.328125, "deepnote_cell_type": "code", "deepnote_output_heights": [ 232.34375 ], "deepnote_to_be_reexecuted": false, "execution_millis": 4, "execution_start": 1651867882020, "source_hash": "716201ae", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "0 True\n", "1 True\n", "2 True\n", "3 True\n", "4 True\n", " ... \n", "6428 True\n", "6429 False\n", "6430 False\n", "6431 False\n", "6432 False\n", "Name: pickup_borough, Length: 6341, dtype: bool" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df[\"pickup_borough\"] == \"Manhattan\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(We could probably also store the Boolean values directly in this new \"Manhattan\" column, but it is less confusing to use `0` and `1`.)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "cell_id": "4d049b7bea014e5aa1d3f39e922cd447", "deepnote_cell_height": 99, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 4, "execution_start": 1651867950135, "source_hash": "8abe422d", "tags": [] }, "outputs": [], "source": [ "# Put a 1 (for True) where the value is Manhattan\n", "df.loc[df[\"pickup_borough\"] == \"Manhattan\", \"Manhattan\"] = 1" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "cell_id": "40e9bffea7c1449a97e3b221edd3583c", "deepnote_cell_height": 600, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 173, "execution_start": 
1651867954447, "source_hash": "f804c160", "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pickupdropoffpassengersdistancefaretiptollstotalcolorpaymentpickup_zonedropoff_zonepickup_boroughdropoff_boroughManhattan
02019-03-23 20:21:092019-03-23 20:27:2411.607.02.150.012.95yellowcredit cardLenox Hill WestUN/Turtle Bay SouthManhattanManhattan1
12019-03-04 16:11:552019-03-04 16:19:0010.795.00.000.09.30yellowcashUpper West Side SouthUpper West Side SouthManhattanManhattan1
22019-03-27 17:53:012019-03-27 18:00:2511.377.52.360.014.16yellowcredit cardAlphabet CityWest VillageManhattanManhattan1
32019-03-10 01:23:592019-03-10 01:49:5117.7027.06.150.036.95yellowcredit cardHudson SqYorkville WestManhattanManhattan1
42019-03-30 13:27:422019-03-30 13:37:1432.169.01.100.013.40yellowcredit cardMidtown EastYorkville WestManhattanManhattan1
................................................
64282019-03-31 09:51:532019-03-31 09:55:2710.754.51.060.06.36greencredit cardEast Harlem NorthCentral Harlem NorthManhattanManhattan1
64292019-03-31 17:38:002019-03-31 18:34:23118.7458.00.000.058.80greencredit cardJamaicaEast Concourse/Concourse VillageQueensBronx0
64302019-03-23 22:55:182019-03-23 23:14:2514.1416.00.000.017.30greencashCrown Heights NorthBushwick NorthBrooklynBrooklyn0
64312019-03-04 10:09:252019-03-04 10:14:2911.126.00.000.06.80greencredit cardEast New YorkEast Flatbush/Remsen VillageBrooklynBrooklyn0
64322019-03-13 19:31:222019-03-13 19:48:0213.8515.03.360.020.16greencredit cardBoerum HillWindsor TerraceBrooklynBrooklyn0
\n", "

6341 rows × 15 columns

\n", "
" ], "text/plain": [ " pickup dropoff passengers distance fare \\\n", "0 2019-03-23 20:21:09 2019-03-23 20:27:24 1 1.60 7.0 \n", "1 2019-03-04 16:11:55 2019-03-04 16:19:00 1 0.79 5.0 \n", "2 2019-03-27 17:53:01 2019-03-27 18:00:25 1 1.37 7.5 \n", "3 2019-03-10 01:23:59 2019-03-10 01:49:51 1 7.70 27.0 \n", "4 2019-03-30 13:27:42 2019-03-30 13:37:14 3 2.16 9.0 \n", "... ... ... ... ... ... \n", "6428 2019-03-31 09:51:53 2019-03-31 09:55:27 1 0.75 4.5 \n", "6429 2019-03-31 17:38:00 2019-03-31 18:34:23 1 18.74 58.0 \n", "6430 2019-03-23 22:55:18 2019-03-23 23:14:25 1 4.14 16.0 \n", "6431 2019-03-04 10:09:25 2019-03-04 10:14:29 1 1.12 6.0 \n", "6432 2019-03-13 19:31:22 2019-03-13 19:48:02 1 3.85 15.0 \n", "\n", " tip tolls total color payment pickup_zone \\\n", "0 2.15 0.0 12.95 yellow credit card Lenox Hill West \n", "1 0.00 0.0 9.30 yellow cash Upper West Side South \n", "2 2.36 0.0 14.16 yellow credit card Alphabet City \n", "3 6.15 0.0 36.95 yellow credit card Hudson Sq \n", "4 1.10 0.0 13.40 yellow credit card Midtown East \n", "... ... ... ... ... ... ... \n", "6428 1.06 0.0 6.36 green credit card East Harlem North \n", "6429 0.00 0.0 58.80 green credit card Jamaica \n", "6430 0.00 0.0 17.30 green cash Crown Heights North \n", "6431 0.00 0.0 6.80 green credit card East New York \n", "6432 3.36 0.0 20.16 green credit card Boerum Hill \n", "\n", " dropoff_zone pickup_borough dropoff_borough \\\n", "0 UN/Turtle Bay South Manhattan Manhattan \n", "1 Upper West Side South Manhattan Manhattan \n", "2 West Village Manhattan Manhattan \n", "3 Yorkville West Manhattan Manhattan \n", "4 Yorkville West Manhattan Manhattan \n", "... ... ... ... 
\n", "6428 Central Harlem North Manhattan Manhattan \n", "6429 East Concourse/Concourse Village Queens Bronx \n", "6430 Bushwick North Brooklyn Brooklyn \n", "6431 East Flatbush/Remsen Village Brooklyn Brooklyn \n", "6432 Windsor Terrace Brooklyn Brooklyn \n", "\n", " Manhattan \n", "0 1 \n", "1 1 \n", "2 1 \n", "3 1 \n", "4 1 \n", "... ... \n", "6428 1 \n", "6429 0 \n", "6430 0 \n", "6431 0 \n", "6432 0 \n", "\n", "[6341 rows x 15 columns]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now replace the old \"pickup_borough\" column with the newly created \"Manhattan\" column." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "cell_id": "a49045990af24d4c81824ab7ddc663c7", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 0, "execution_start": 1651868159514, "source_hash": "eb320fc1", "tags": [] }, "outputs": [], "source": [ "cols = ['distance', 'tip', 'tolls', 'Manhattan']" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "cell_id": "b457cc6309874598b39086ba14bf2265", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 17, "execution_start": 1651868183429, "source_hash": "623c24f0", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "LinearRegression()" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reg.fit(df[cols],df[\"total\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The goal of the `fit` method is to find the following coefficients, as well as the intercept." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": { "cell_id": "84d9871e471e464abd94766d4a1ce6fe", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 9, "execution_start": 1651868204225, "source_hash": "c0bd65f9", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "array([2.6294669 , 1.3306588 , 1.11487961, 1.55477579])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reg.coef_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You should interpret the \"distance\" coefficient shown below as saying that the model for the total cost of the taxi ride includes a term equal to 2.63 times the distance traveled. In other words, with the other inputs held fixed, the model charges about \\$2.63 per mile." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "cell_id": "0c6bbb2f0f84437786d04013365833c6", "deepnote_cell_height": 194.953125, "deepnote_cell_type": "code", "deepnote_output_heights": [ 97.96875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 11, "execution_start": 1651868224851, "source_hash": "77d2cb3a", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "distance 2.629467\n", "tip 1.330659\n", "tolls 1.114880\n", "Manhattan 1.554776\n", "dtype: float64" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.Series(reg.coef_, index=cols)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "cell_id": "3638c804896b4ed28087b03c47ca16b6", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 9, "execution_start": 1651868437350, "source_hash": "1661cbe7", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "6.170557177757704" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reg.intercept_" ] }, { "cell_type": "code", 
"execution_count": 24, "metadata": { "cell_id": "bf636924d7214855bbd7ab4eecb6bcbd", "deepnote_cell_height": 272, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 102, "execution_start": 1651868520714, "source_hash": "610ca1b2", "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pickupdropoffpassengersdistancefaretiptollstotalcolorpaymentpickup_zonedropoff_zonepickup_boroughdropoff_boroughManhattan
02019-03-23 20:21:092019-03-23 20:27:2411.607.02.150.012.95yellowcredit cardLenox Hill WestUN/Turtle Bay SouthManhattanManhattan1
12019-03-04 16:11:552019-03-04 16:19:0010.795.00.000.09.30yellowcashUpper West Side SouthUpper West Side SouthManhattanManhattan1
\n", "
" ], "text/plain": [ " pickup dropoff passengers distance fare tip \\\n", "0 2019-03-23 20:21:09 2019-03-23 20:27:24 1 1.60 7.0 2.15 \n", "1 2019-03-04 16:11:55 2019-03-04 16:19:00 1 0.79 5.0 0.00 \n", "\n", " tolls total color payment pickup_zone \\\n", "0 0.0 12.95 yellow credit card Lenox Hill West \n", "1 0.0 9.30 yellow cash Upper West Side South \n", "\n", " dropoff_zone pickup_borough dropoff_borough Manhattan \n", "0 UN/Turtle Bay South Manhattan Manhattan 1 \n", "1 Upper West Side South Manhattan Manhattan 1 " ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df[:2]" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "cell_id": "7b17c01060a14ecab3ce50f8caf56c3a", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 7, "execution_start": 1651868561204, "source_hash": "dea74e1c", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "array([14.79339642, 9.80261182])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reg.predict(df[:2][cols])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For example, we can view `14.8` as the predicted output for the 0th row. The `predict` method isn't doing anything mysterious. It's just evaluating this linear function on the given inputs. Here is the by-hand computation for the 0th row." 
] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "14.787500000000001" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "2.63*1.6+1.33*2.15+1.11*0+1.55*1+6.17" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "e6a10e193f07477694fbe10b4c9276e5", "deepnote_cell_height": 189.5625, "deepnote_cell_type": "markdown", "tags": [] }, "source": [ "## Polynomial regression\n", "\n", "Last time, we fit a degree-9 polynomial model to this data, using \"distance\" as the (only) input variable and using \"total\" as the output variable. The code from last time is below.\n", "\n", "Using 100 training points, adapt the code from last time to fit models of different degrees, for each degree from 1 to 24 (the loops below use `range(1,25)`, whose upper bound is exclusive). Plot the resulting polynomials for $0 \\leq x \\leq M$, where $M$ is the maximum \"distance\" value within the training data.\n", "\n", "A lot of this code was copied from last time, and then adjusted to the current goals." ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "cell_id": "2235c85242b94664b25889188080e731", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 0, "execution_start": 1651868813242, "source_hash": "746a4dbc", "tags": [] }, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We're using 100 training points instead of the 40 we used last time. With more data points, we should expect slightly less overfitting this time." 
] }, { "cell_type": "code", "execution_count": 31, "metadata": { "cell_id": "302431b4d1b9448e9756128e893b371c", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 10, "execution_start": 1651868851207, "source_hash": "e5cb177c", "tags": [] }, "outputs": [], "source": [ "df_train, df_test = train_test_split(df, train_size=100)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "cell_id": "f2b50781bde34eb8835d5af09977b66f", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 10, "execution_start": 1651868856995, "source_hash": "db981d08", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(100, 15)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train.shape" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "cell_id": "5a8731f236144a1788b72601b501d36f", "deepnote_cell_height": 135, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 0, "execution_start": 1651868879671, "source_hash": "1fb53adb", "tags": [] }, "outputs": [], "source": [ "c = alt.Chart(df_train).mark_circle().encode(\n", " x=\"distance\",\n", " y=\"total\"\n", ")" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "cell_id": "6dfedd23462546f09646ff7c68163dbc", "deepnote_cell_height": 458, "deepnote_cell_type": "code", "deepnote_output_heights": [ 361 ], "deepnote_to_be_reexecuted": false, "execution_millis": 722, "execution_start": 1651868881576, "source_hash": "957caa7e", "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.Chart(...)" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "cell_id": "e2597158c406468b8a73184b3514f60c", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 18, "execution_start": 1651869010986, "source_hash": "2db08a7e", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "26.92" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train[\"distance\"].max()" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "cell_id": "60475d07b67d4e8c949dec2eac9a1097", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 1, "execution_start": 1651869123315, "source_hash": "c2602aa8", "tags": [] }, "outputs": [], "source": [ "import numpy as np" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "cell_id": "371c9e9cfae54cb0b669921e753c069f", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 13, "execution_start": 1651869125772, "source_hash": "feda1de0", "tags": [] }, "outputs": [], "source": [ "df_plot = pd.DataFrame({\"distance\":np.arange(0,df_train[\"distance\"].max()+0.1,0.1)})" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "cell_id": "d1a4200eee8c40f6957efd54e0274819", "deepnote_cell_height": 395, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 1, "execution_start": 1651869265812, "source_hash": "58190bf0", "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
distance
00.0
10.1
20.2
30.3
40.4
\n", "
" ], "text/plain": [ " distance\n", "0 0.0\n", "1 0.1\n", "2 0.2\n", "3 0.3\n", "4 0.4" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_plot.head()" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "cell_id": "2967701b35294fb09741c0df51ea28d0", "deepnote_cell_height": 171, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 24, "execution_start": 1651869560543, "source_hash": "c34b5e23", "tags": [] }, "outputs": [], "source": [ "cols = []\n", "for deg in range(1,25):\n", " col = f\"d{deg}\"\n", " cols.append(col)\n", " for x in [df_train, df_plot]:\n", " x[col] = x[\"distance\"]**deg" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "cell_id": "df16c1fb887e4412a443185b010a9a0d", "deepnote_cell_height": 395, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 39, "execution_start": 1651869335836, "source_hash": "58190bf0", "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
distanced1d2d3d4d5d6d7d8d9...d15d16d17d18d19d20d21d22d23d24
00.00.00.000.0000.00000.000000.0000000.000000e+000.000000e+000.000000e+00...0.000000e+000.000000e+000.000000e+000.000000e+000.000000e+000.000000e+000.000000e+000.000000e+000.000000e+000.000000e+00
10.10.10.010.0010.00010.000010.0000011.000000e-071.000000e-081.000000e-09...1.000000e-151.000000e-161.000000e-171.000000e-181.000000e-191.000000e-201.000000e-211.000000e-221.000000e-231.000000e-24
20.20.20.040.0080.00160.000320.0000641.280000e-052.560000e-065.120000e-07...3.276800e-116.553600e-121.310720e-122.621440e-135.242880e-141.048576e-142.097152e-154.194304e-168.388608e-171.677722e-17
30.30.30.090.0270.00810.002430.0007292.187000e-046.561000e-051.968300e-05...1.434891e-084.304672e-091.291402e-093.874205e-101.162261e-103.486784e-111.046035e-113.138106e-129.414318e-132.824295e-13
40.40.40.160.0640.02560.010240.0040961.638400e-036.553600e-042.621440e-04...1.073742e-064.294967e-071.717987e-076.871948e-082.748779e-081.099512e-084.398047e-091.759219e-097.036874e-102.814750e-10
\n", "

5 rows × 25 columns

\n", "
" ], "text/plain": [ " distance d1 d2 d3 d4 d5 d6 d7 \\\n", "0 0.0 0.0 0.00 0.000 0.0000 0.00000 0.000000 0.000000e+00 \n", "1 0.1 0.1 0.01 0.001 0.0001 0.00001 0.000001 1.000000e-07 \n", "2 0.2 0.2 0.04 0.008 0.0016 0.00032 0.000064 1.280000e-05 \n", "3 0.3 0.3 0.09 0.027 0.0081 0.00243 0.000729 2.187000e-04 \n", "4 0.4 0.4 0.16 0.064 0.0256 0.01024 0.004096 1.638400e-03 \n", "\n", " d8 d9 ... d15 d16 d17 \\\n", "0 0.000000e+00 0.000000e+00 ... 0.000000e+00 0.000000e+00 0.000000e+00 \n", "1 1.000000e-08 1.000000e-09 ... 1.000000e-15 1.000000e-16 1.000000e-17 \n", "2 2.560000e-06 5.120000e-07 ... 3.276800e-11 6.553600e-12 1.310720e-12 \n", "3 6.561000e-05 1.968300e-05 ... 1.434891e-08 4.304672e-09 1.291402e-09 \n", "4 6.553600e-04 2.621440e-04 ... 1.073742e-06 4.294967e-07 1.717987e-07 \n", "\n", " d18 d19 d20 d21 d22 \\\n", "0 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 \n", "1 1.000000e-18 1.000000e-19 1.000000e-20 1.000000e-21 1.000000e-22 \n", "2 2.621440e-13 5.242880e-14 1.048576e-14 2.097152e-15 4.194304e-16 \n", "3 3.874205e-10 1.162261e-10 3.486784e-11 1.046035e-11 3.138106e-12 \n", "4 6.871948e-08 2.748779e-08 1.099512e-08 4.398047e-09 1.759219e-09 \n", "\n", " d23 d24 \n", "0 0.000000e+00 0.000000e+00 \n", "1 1.000000e-23 1.000000e-24 \n", "2 8.388608e-17 1.677722e-17 \n", "3 9.414318e-13 2.824295e-13 \n", "4 7.036874e-10 2.814750e-10 \n", "\n", "[5 rows x 25 columns]" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_plot.head()" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "cell_id": "29226b21f26a414cb83bdbfe73b31d37", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 16, "execution_start": 1651869388498, "source_hash": "cb73c151", "tags": [ "output_scroll" ] }, "outputs": [ { "data": { "text/plain": [ "['d1',\n", " 'd2',\n", " 'd3',\n", " 'd4',\n", " 
'd5',\n", " 'd6',\n", " 'd7',\n", " 'd8',\n", " 'd9',\n", " 'd10',\n", " 'd11',\n", " 'd12',\n", " 'd13',\n", " 'd14',\n", " 'd15',\n", " 'd16',\n", " 'd17',\n", " 'd18',\n", " 'd19',\n", " 'd20',\n", " 'd21',\n", " 'd22',\n", " 'd23',\n", " 'd24']" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cols" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "cell_id": "cf6545c1c24342d2b155a6da0170240e", "deepnote_cell_height": 118.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 21.1875 ], "deepnote_to_be_reexecuted": false, "execution_millis": 4, "execution_start": 1651869644009, "source_hash": "4dd5a516", "tags": [] }, "outputs": [ { "data": { "text/plain": [ "['d1', 'd2', 'd3', 'd4']" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cols[:4]" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "cell_id": "4e986530fdb645e0bf924b36c1d69302", "deepnote_cell_height": 279, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 311, "execution_start": 1651869933486, "source_hash": "6a82040", "tags": [] }, "outputs": [], "source": [ "chart_list = []\n", "\n", "for deg in range(1,25):\n", " subcols = cols[:deg]\n", " reg = LinearRegression()\n", " reg.fit(df_train[subcols],df_train[\"total\"])\n", " df_plot[f\"Pred{deg}\"] = reg.predict(df_plot[subcols])\n", " c_temp = alt.Chart(df_plot).mark_line(color=\"red\", clip=True).encode(\n", " x=\"distance\",\n", " y=alt.Y(f\"Pred{deg}\", scale=alt.Scale(domain=(0,200)))\n", " )\n", " chart_list.append(c_temp)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "cell_id": "1b85bb3f98c54cc49cb81044591236a6", "deepnote_cell_height": 81, "deepnote_cell_type": "code", "deepnote_output_heights": [ 361 ], "deepnote_to_be_reexecuted": false, "execution_millis": 0, "execution_start": 1651870095040, "source_hash": "9e4c707b", "tags": [] }, "outputs": [], "source": 
[ "both_charts = [c+d for d in chart_list]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The input to `alt.vconcat` needs to be one or more Altair charts, not a list of Altair charts." ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "cell_id": "c0c31619c7944a21a1c4b3b33fd74a81", "deepnote_cell_height": 144.1875, "deepnote_cell_type": "code", "deepnote_output_heights": [ 361 ], "deepnote_to_be_reexecuted": false, "execution_millis": 23, "execution_start": 1651870138821, "source_hash": "8ec9aa9", "tags": [ "output_scroll" ] }, "outputs": [ { "ename": "ValueError", "evalue": "Only chart objects can be used in VConcatChart.", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/var/folders/8j/gshrlmtn7dg4qtztj4d4t_w40000gn/T/ipykernel_15733/187346778.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0malt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mboth_charts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/miniconda3/envs/math10s22/lib/python3.7/site-packages/altair/vegalite/v4/api.py\u001b[0m in \u001b[0;36mvconcat\u001b[0;34m(*charts, **kwargs)\u001b[0m\n\u001b[1;32m 2330\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mcharts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2331\u001b[0m \u001b[0;34m\"\"\"Concatenate charts vertically\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2332\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
\u001b[0mVConcatChart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvconcat\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcharts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2333\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2334\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/math10s22/lib/python3.7/site-packages/altair/vegalite/v4/api.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, data, vconcat, **kwargs)\u001b[0m\n\u001b[1;32m 2304\u001b[0m \u001b[0;31m# TODO: move common data to top level?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2305\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mspec\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mvconcat\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2306\u001b[0;31m \u001b[0m_check_if_valid_subspec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mspec\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"VConcatChart\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2307\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mVConcatChart\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvconcat\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvconcat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2308\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvconcat\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0m_combine_subchart_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvconcat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/math10s22/lib/python3.7/site-packages/altair/vegalite/v4/api.py\u001b[0m in \u001b[0;36m_check_if_valid_subspec\u001b[0;34m(spec, classname)\u001b[0m\n\u001b[1;32m 2072\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2073\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mspec\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSchemaBase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2074\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Only chart objects can be used in {0}.\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclassname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2075\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mattr\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mTOPLEVEL_ONLY_KEYS\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2076\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mspec\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSchemaBase\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mValueError\u001b[0m: Only chart objects can be used in VConcatChart." 
] } ], "source": [ "alt.vconcat(both_charts)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So we use list unpacking. Notice how the overfitting gets more extreme as the degree of the polynomial gets higher." ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "cell_id": "026a34dcad3a4a5ca0a824734b7c464e", "deepnote_cell_height": 66, "deepnote_cell_type": "code", "deepnote_to_be_reexecuted": false, "execution_millis": 5647, "execution_start": 1651870154472, "source_hash": "6ddab992", "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.VConcatChart(...)" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "alt.vconcat(*both_charts)" ] } ], "metadata": { "deepnote": {}, "deepnote_execution_queue": [], "deepnote_notebook_id": "a77823bb-2b5b-436d-9a2e-689e1064d274", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" } }, "nbformat": 4, "nbformat_minor": 4 }