Wine Quality
Author: Paula Mendez-Lagunas
Course Project, UC Irvine, Math 10, S23
Introduction
My project focuses on a dataset of red wines. Its columns describe chemical properties of each wine, and one column assigns each wine a quality score. After refining the original dataframe, I applied classification machine learning models to the data. The two models I chose are DecisionTreeClassifier and KNeighborsClassifier.
Importing
import pandas as pd
import altair as alt
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
Section 1: Preparing the Data
In this section I use pandas to gain information about the original dataframe and create a new column named “class” that indicates whether each wine is good or bad based on its quality score. Finally, I create a new dataframe with a more balanced number of good and bad wines.
df_temp = pd.read_csv("winequality-red.csv")
# Checking if there are any columns with missing values; notice all columns are numerical
df_temp.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1599 entries, 0 to 1598
Data columns (total 12 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 fixed acidity 1599 non-null float64
1 volatile acidity 1599 non-null float64
2 citric acid 1599 non-null float64
3 residual sugar 1599 non-null float64
4 chlorides 1599 non-null float64
5 free sulfur dioxide 1599 non-null float64
6 total sulfur dioxide 1599 non-null float64
7 density 1599 non-null float64
8 pH 1599 non-null float64
9 sulphates 1599 non-null float64
10 alcohol 1599 non-null float64
11 quality 1599 non-null int64
dtypes: float64(11), int64(1)
memory usage: 150.0 KB
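The info() output answers the missing-value question indirectly: every column reports 1599 non-null values out of 1599 entries. A more direct check is to count the missing entries in each column with isna(). Below is a minimal sketch on a small made-up frame; the helper name missing_counts is hypothetical, and on the real data the call would be missing_counts(df_temp).

```python
import numpy as np
import pandas as pd


def missing_counts(df: pd.DataFrame) -> pd.Series:
    """Return the number of missing (NaN) entries in each column (hypothetical helper)."""
    return df.isna().sum()


# Toy data made up for illustration: one missing alcohol value
toy = pd.DataFrame({"alcohol": [9.4, np.nan, 10.2], "quality": [5, 6, 7]})
counts = missing_counts(toy)
print(counts)
```

For this dataset, missing_counts(df_temp).sum() should be 0, consistent with the info() output.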
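I want to treat this as a classification problem, so I add a new column that labels each wine as good or bad. A wine counts as “good” if its quality score is greater than 6 (that is, 7 or 8, since the maximum score is 8) and “bad” otherwise. I do this using map with a lambda function.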
df_temp["quality"].max()
8
df_temp["class"] = df_temp["quality"].map(lambda x: "good" if x>6 else "bad")
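The lambda-and-map step labels one value at a time. The same labels can be produced with a single vectorized comparison using NumPy's np.where. This is a minimal sketch, assuming the same cutoff as above (a score above 6 is “good”); the helper name label_quality and the toy scores are hypothetical.

```python
import numpy as np
import pandas as pd


def label_quality(quality: pd.Series, cutoff: int = 6) -> pd.Series:
    """Label scores above `cutoff` as "good" and the rest as "bad" (hypothetical helper)."""
    # np.where compares the whole column at once instead of calling a lambda per value
    labels = np.where(quality > cutoff, "good", "bad")
    return pd.Series(labels, index=quality.index, name="class")


scores = pd.Series([5, 6, 7, 8, 3])
via_where = label_quality(scores)
via_map = scores.map(lambda x: "good" if x > 6 else "bad")
print(via_where.tolist())
print(via_map.tolist())
```

On the real data this would be df_temp["class"] = label_quality(df_temp["quality"]), which should give the same labels as the map version.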
# This helps visualize the proportion of good wine to bad wine in the original dataframe
alt.Chart(df_temp).mark_bar().encode(
x = "class",
y = "count()",
tooltip = ["count()"]
)
To create a more balanced dataframe, I take 250 random rows whose class is “bad” and concatenate them with a dataframe containing all the rows whose class is “good”.
# I used a random state to get reproducible results
df_good = df_temp[df_temp["class"] == "good"]
df_bad = df_temp[df_temp["class"] == "bad"].sample(250, random_state= 97)
# I use axis = 0 since I want to join them along their rows
df = pd.concat((df_good, df_bad), axis= 0)
df.shape
(467, 13)
My final dataframe is therefore named df. It contains 467 rows (250 labeled “bad” and the remaining 217 labeled “good”) and 13 columns (the 12 original columns plus the “class” column we added).
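A quick way to confirm the class counts after sampling is value_counts(). The sketch below wraps the sample-and-concat step in a small helper and checks it on made-up toy data; the name undersample is hypothetical, and it assumes, as above, that only the majority class is sampled while every minority row is kept.

```python
import pandas as pd


def undersample(df: pd.DataFrame, label_col: str, majority: str, n: int,
                random_state: int = 0) -> pd.DataFrame:
    """Keep all non-majority rows plus n randomly sampled majority rows (hypothetical helper)."""
    minority_rows = df[df[label_col] != majority]
    # sample() draws without replacement by default, so no row is picked twice
    majority_rows = df[df[label_col] == majority].sample(n, random_state=random_state)
    return pd.concat((minority_rows, majority_rows), axis=0)


toy = pd.DataFrame({"class": ["bad"] * 10 + ["good"] * 3, "x": range(13)})
balanced = undersample(toy, "class", majority="bad", n=4, random_state=97)
print(balanced["class"].value_counts())
```

On the real data, undersample(df_temp, "class", majority="bad", n=250, random_state=97) makes the same sample() call as the code above, so df["class"].value_counts() should report 250 “bad” and 217 “good” rows.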
Section 2: Visualizing Data Relations
In this section I mainly use Altair to create charts of specific data relations I want to see. To avoid rewriting the same code, I write a function called ‘make_chart’.
# This helps me understand the distribution of the data in each column
df.describe()
| | fixed acidity | volatile acidity | citric acid | residual sugar | chlorides | free sulfur dioxide | total sulfur dioxide | density | pH | sulphates | alcohol | quality |
---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 467.000000 | 467.000000 | 467.000000 | 467.000000 | 467.000000 | 467.000000 | 467.000000 | 467.000000 | 467.000000 | 467.000000 | 467.000000 | 467.000000 |
mean | 8.438972 | 0.486895 | 0.297559 | 2.588544 | 0.083615 | 14.706638 | 41.557816 | 0.996456 | 3.307323 | 0.686574 | 10.852819 | 6.179872 |
std | 1.778647 | 0.177284 | 0.203943 | 1.415962 | 0.042843 | 9.873047 | 32.468504 | 0.002044 | 0.147025 | 0.158632 | 1.171713 | 0.977088 |
min | 4.600000 | 0.120000 | 0.000000 | 0.900000 | 0.012000 | 1.000000 | 7.000000 | 0.990070 | 2.880000 | 0.370000 | 8.700000 | 3.000000 |
25% | 7.200000 | 0.350000 | 0.100000 | 1.900000 | 0.066000 | 7.000000 | 20.000000 | 0.995160 | 3.210000 | 0.570000 | 9.800000 | 5.000000 |
50% | 8.100000 | 0.460000 | 0.320000 | 2.200000 | 0.077000 | 12.000000 | 31.000000 | 0.996430 | 3.300000 | 0.670000 | 10.800000 | 6.000000 |
75% | 9.550000 | 0.605000 | 0.450000 | 2.600000 | 0.089000 | 19.000000 | 52.000000 | 0.997700 | 3.390000 | 0.770000 | 11.700000 | 7.000000 |
max | 15.600000 | 1.330000 | 0.790000 | 15.400000 | 0.467000 | 55.000000 | 289.000000 | 1.003690 | 3.900000 | 1.560000 | 14.000000 | 8.000000 |
From the summary above, some columns have much larger scales (for example, total sulfur dioxide) than others (such as density). Also, comparing the mean, standard deviation, and maximum of each column suggests that most columns have outliers: in many columns the maximum lies far above the mean relative to the standard deviation.
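One way to make the outlier observation concrete is to measure how many standard deviations each column's maximum sits above its mean, using the mean, std, and max rows of describe(). For example, from the table above, residual sugar's maximum is (15.4 − 2.59) / 1.42 ≈ 9 standard deviations above its mean, and total sulfur dioxide's is (289 − 41.56) / 32.47 ≈ 7.6. The sketch below computes this for every numeric column; max_z_scores is a hypothetical helper name and the toy data is made up.

```python
import pandas as pd


def max_z_scores(df: pd.DataFrame) -> pd.Series:
    """Standard deviations between each numeric column's max and its mean (hypothetical helper)."""
    summary = df.describe()  # only numeric columns are summarized by default
    z = (summary.loc["max"] - summary.loc["mean"]) / summary.loc["std"]
    return z.sort_values(ascending=False)


# Toy data: column "a" has one extreme value, column "b" does not
toy = pd.DataFrame({"a": [1, 2, 3, 4, 100], "b": [1, 2, 3, 4, 5], "label": list("vwxyz")})
print(max_z_scores(toy))
```

On the real data the call would be max_z_scores(df). A large value is only a rough signal of outliers, since a strongly skewed column can also produce one.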
# The following code is adapted from code used in Worksheet 7
def make_chart(col):
    return alt.Chart(df).mark_circle().encode(
        x=alt.X("quality", scale=alt.Scale(zero=False)),
        y=alt.Y(col, scale=alt.Scale(zero=False)),
        color="class",
    )
Using the function above, I make one chart per input feature, plotting it against quality. Since each column has a different scale, I set ‘zero=False’ so that each axis fits the range of its own data instead of being stretched to include 0.
# Only the first 11 columns are input features; the last 2 ("quality" and "class") are outputs
cols = list(df.columns[:11])
chart_list = [make_chart(col) for col in cols]
# This code was also taken from Worksheet 7
total_chart = alt.vconcat(*chart_list)
total_chart
Based on the charts above, almost every feature has at least one outlier, most noticeably residual sugar, chlorides, and total sulfur dioxide. The charts also show no clear pattern between any single feature and quality, so it would be hard to classify the wines using only one input feature.
Section 3: Machine Learning Models
In this final section I apply the DecisionTreeClassifier and the KNeighborsClassifier models and compare their accuracy. I also use train_test_split to test the DecisionTreeClassifier model for overfitting. For the KNeighborsClassifier model I create a confusion matrix to see its prediction results.
# Instantiate; I decided to use 15 max leaf nodes because there are 11 input variables and 15>11
clf = DecisionTreeClassifier(max_leaf_nodes= 15, random_state= 126)
#Fit; I am using all 11 feature columns and want to predict whether the wine is "good" or "bad"
clf.fit(df[cols], df["class"])
DecisionTreeClassifier(max_leaf_nodes=15, random_state=126)
# This shows that the most influential feature is alcohol
# Residual sugar and chlorides, two of the 3 columns I said had the most noticeable outliers, are near the bottom of this list,
# while total sulfur dioxide is in the middle; I wonder why
pd.Series(clf.feature_importances_ , clf.feature_names_in_).sort_values(ascending= False)
alcohol 0.507042
volatile acidity 0.171973
sulphates 0.124870
fixed acidity 0.065348
citric acid 0.034490
total sulfur dioxide 0.031346
density 0.022035
pH 0.021488
chlorides 0.021406
residual sugar 0.000000
free sulfur dioxide 0.000000
dtype: float64
clf.score(df[cols], df["class"])
0.8758029978586723
This score is well above what random guessing would achieve, but it was computed on the same data the tree was trained on, which makes me wonder whether the model is overfitting. To test for overfitting, I divide the data into a training set and a test set using train_test_split.
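A concrete reference point for “random guessing” is the majority-class baseline: the accuracy of always predicting the most common label. In df that label is “bad”, so the baseline is 250 / 467 ≈ 0.535, well below the tree's 0.876. A minimal sketch (the helper name is hypothetical):

```python
import pandas as pd


def majority_baseline(y: pd.Series) -> float:
    """Accuracy obtained by always predicting the most frequent label (hypothetical helper)."""
    # normalize=True turns counts into proportions; the largest proportion is the baseline accuracy
    return float(y.value_counts(normalize=True).max())


# Same class counts as df: 250 "bad" and 217 "good"
y = pd.Series(["bad"] * 250 + ["good"] * 217)
print(round(majority_baseline(y), 3))  # 0.535
```

On the real data the call would be majority_baseline(df["class"]). scikit-learn's DummyClassifier(strategy="most_frequent") computes the same baseline as a fitted model.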
# The training set has 60% of the data
X_train, X_test, y_train, y_test = train_test_split(
df[cols], df["class"], train_size= 0.6, random_state= 0
)
# This time I only want to fit using the X_train and y_train data
clf.fit(X_train, y_train)
DecisionTreeClassifier(max_leaf_nodes=15, random_state=126)
# This describes its accuracy for the training data
clf.score(X_train, y_train)
0.9214285714285714
# This describes its accuracy on the test data, which it has never seen
clf.score(X_test, y_test)
0.7433155080213903
Comparing the classifier’s scores, the training score (about 0.92) is noticeably higher than the test score (about 0.74), a gap of roughly 18 percentage points. This gap suggests that the model is overfitting somewhat, although I’m not completely sure how large a gap should count as overfitting.
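The scores can also be turned back into counts of wines. Assuming train_test_split rounds the training share down, train_size=0.6 gives floor(0.6 × 467) = 280 training rows and 187 test rows, and both scores above are exact fractions of those sizes. A small arithmetic sketch:

```python
import math

n_rows = 467
n_train = math.floor(0.6 * n_rows)  # rows in the training set
n_test = n_rows - n_train           # remaining rows go to the test set

train_score = 0.9214285714285714   # clf.score(X_train, y_train) from above
test_score = 0.7433155080213903    # clf.score(X_test, y_test) from above

# Accuracy = correct / total, so multiplying back recovers the number of correct predictions
correct_train = round(train_score * n_train)
correct_test = round(test_score * n_test)
gap = train_score - test_score

print(n_train, n_test)              # 280 187
print(correct_train, correct_test)  # 258 139
print(f"gap: {gap:.3f}")            # gap: 0.178
```

So the tree misclassifies 22 of its 280 training wines but 48 of its 187 test wines.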
Next I want to try a different machine learning model, K-Nearest Neighbors (KNeighborsClassifier).
#Instantiate; 16 seems like a good number to try
# I actually initially tried 18 but the score was lower
knc = KNeighborsClassifier(n_neighbors= 16)
#Fit; I use the same input and output features as before
knc.fit(df[cols], df["class"])
KNeighborsClassifier(n_neighbors=16)
knc.score(df[cols], df["class"])
0.721627408993576
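This model’s score (about 0.72) is still well above random guessing, but it is lower than the DecisionTreeClassifier’s score of about 0.88. Note that both of these scores were computed on the same data the models were fit on.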
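One possible reason for the lower KNN score is that KNeighborsClassifier measures Euclidean distance by default, so columns with large numeric ranges dominate the neighbor search. The sketch below compares two wines from the 50/50 table further down (rows 7 and 29 of df) and reports each feature's share of the squared distance, first on the raw values and then after dividing every column by its standard deviation from df.describe(). The helper name is hypothetical, and this only illustrates the effect; it does not show that scaling would improve the score.

```python
import numpy as np

features = ["fixed acidity", "volatile acidity", "citric acid", "residual sugar",
            "chlorides", "free sulfur dioxide", "total sulfur dioxide", "density",
            "pH", "sulphates", "alcohol"]

# Feature values of rows 7 and 29 of df, copied from the 50/50 table
wine_7 = np.array([7.3, 0.650, 0.00, 1.20, 0.065, 15.0, 21.0, 0.99460, 3.39, 0.47, 10.0])
wine_29 = np.array([7.8, 0.645, 0.00, 2.00, 0.082, 8.0, 16.0, 0.99640, 3.38, 0.59, 9.8])

# Standard deviations from df.describe()
stds = np.array([1.778647, 0.177284, 0.203943, 1.415962, 0.042843, 9.873047,
                 32.468504, 0.002044, 0.147025, 0.158632, 1.171713])


def squared_distance_shares(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Fraction of the squared Euclidean distance contributed by each feature (hypothetical helper)."""
    sq = (a - b) ** 2
    return sq / sq.sum()


raw_shares = squared_distance_shares(wine_7, wine_29)
# Dividing by the std puts every column on a comparable scale (centering would not change differences)
scaled_shares = squared_distance_shares(wine_7 / stds, wine_29 / stds)

for name, raw, scaled in zip(features, raw_shares, scaled_shares):
    print(f"{name:22s} raw {raw:.3f}  scaled {scaled:.3f}")
```

On these two rows, the two sulfur dioxide columns account for about 99% of the raw squared distance, while after scaling no single column contributes more than about 31%. scikit-learn's StandardScaler, combined with the classifier in a pipeline, is a standard way to apply this kind of scaling.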
Next I want to see which wines, and how many, were misclassified.
df["pred_knc"] = knc.predict(df[cols])
knc.classes_
array(['bad', 'good'], dtype=object)
# This shows the model's confidence in each prediction; note that several rows are 50/50
arr = knc.predict_proba(df[cols])
arr
array([[0.5 , 0.5 ],
[0.4375, 0.5625],
[0.75 , 0.25 ],
[0.4375, 0.5625],
[0.875 , 0.125 ],
[0.5625, 0.4375],
[0.5625, 0.4375],
[0.6875, 0.3125],
[0.25 , 0.75 ],
[0.25 , 0.75 ],
[0.375 , 0.625 ],
[0.8125, 0.1875],
[0.1875, 0.8125],
[0.1875, 0.8125],
[0.5 , 0.5 ],
[0.25 , 0.75 ],
[0.5 , 0.5 ],
[0.25 , 0.75 ],
[0.625 , 0.375 ],
[0.375 , 0.625 ],
[0.625 , 0.375 ],
[0.5625, 0.4375],
[0.5625, 0.4375],
[0.75 , 0.25 ],
[0.75 , 0.25 ],
[0.25 , 0.75 ],
[0.625 , 0.375 ],
[0.375 , 0.625 ],
[0.1875, 0.8125],
[0.5625, 0.4375],
[0.25 , 0.75 ],
[0.1875, 0.8125],
[0.4375, 0.5625],
[0.4375, 0.5625],
[0.25 , 0.75 ],
[0.1875, 0.8125],
[0.25 , 0.75 ],
[0.5625, 0.4375],
[0.625 , 0.375 ],
[0.5625, 0.4375],
[0.1875, 0.8125],
[0.3125, 0.6875],
[0.625 , 0.375 ],
[0.375 , 0.625 ],
[0.4375, 0.5625],
[0.375 , 0.625 ],
[0.4375, 0.5625],
[0.5 , 0.5 ],
[0.4375, 0.5625],
[0.25 , 0.75 ],
[0.8125, 0.1875],
[0.375 , 0.625 ],
[0.3125, 0.6875],
[0.375 , 0.625 ],
[0.1875, 0.8125],
[0.5625, 0.4375],
[0.125 , 0.875 ],
[0.1875, 0.8125],
[0.3125, 0.6875],
[0.3125, 0.6875],
[0.5 , 0.5 ],
[0.5 , 0.5 ],
[0.4375, 0.5625],
[0.3125, 0.6875],
[0.0625, 0.9375],
[0.3125, 0.6875],
[0.3125, 0.6875],
[0.1875, 0.8125],
[0.1875, 0.8125],
[0.5 , 0.5 ],
[0.125 , 0.875 ],
[0.375 , 0.625 ],
[0.5 , 0.5 ],
[0.5 , 0.5 ],
[0.5 , 0.5 ],
[0.1875, 0.8125],
[0.875 , 0.125 ],
[0.4375, 0.5625],
[0.5 , 0.5 ],
[0.3125, 0.6875],
[0.375 , 0.625 ],
[0.6875, 0.3125],
[0.125 , 0.875 ],
[0.25 , 0.75 ],
[0.125 , 0.875 ],
[0.5 , 0.5 ],
[0.3125, 0.6875],
[0.4375, 0.5625],
[0.6875, 0.3125],
[0.6875, 0.3125],
[0.5625, 0.4375],
[0.625 , 0.375 ],
[0.3125, 0.6875],
[0.5 , 0.5 ],
[0.625 , 0.375 ],
[0.25 , 0.75 ],
[0.1875, 0.8125],
[0.25 , 0.75 ],
[0.5625, 0.4375],
[0.625 , 0.375 ],
[0.625 , 0.375 ],
[0.6875, 0.3125],
[0.6875, 0.3125],
[0.375 , 0.625 ],
[0.375 , 0.625 ],
[0.125 , 0.875 ],
[0.4375, 0.5625],
[0.25 , 0.75 ],
[0.3125, 0.6875],
[0.125 , 0.875 ],
[0.0625, 0.9375],
[0.6875, 0.3125],
[0.625 , 0.375 ],
[0.25 , 0.75 ],
[0.3125, 0.6875],
[0.0625, 0.9375],
[0.375 , 0.625 ],
[0.3125, 0.6875],
[0.3125, 0.6875],
[0.3125, 0.6875],
[0.375 , 0.625 ],
[0.25 , 0.75 ],
[0.125 , 0.875 ],
[0.625 , 0.375 ],
[0.1875, 0.8125],
[0.3125, 0.6875],
[0.1875, 0.8125],
[0.75 , 0.25 ],
[0.25 , 0.75 ],
[0.3125, 0.6875],
[0.3125, 0.6875],
[0.3125, 0.6875],
[0.4375, 0.5625],
[0.1875, 0.8125],
[0.3125, 0.6875],
[0.3125, 0.6875],
[0.1875, 0.8125],
[0.1875, 0.8125],
[0.125 , 0.875 ],
[0.3125, 0.6875],
[0.125 , 0.875 ],
[0.375 , 0.625 ],
[0.375 , 0.625 ],
[0.3125, 0.6875],
[0.375 , 0.625 ],
[0.4375, 0.5625],
[0.5625, 0.4375],
[0.3125, 0.6875],
[0.625 , 0.375 ],
[0.3125, 0.6875],
[0.125 , 0.875 ],
[0.375 , 0.625 ],
[0.125 , 0.875 ],
[0.125 , 0.875 ],
[0.375 , 0.625 ],
[0.1875, 0.8125],
[0.1875, 0.8125],
[0.625 , 0.375 ],
[0.375 , 0.625 ],
[0.8125, 0.1875],
[0.8125, 0.1875],
[0.25 , 0.75 ],
[0.3125, 0.6875],
[0.3125, 0.6875],
[0.4375, 0.5625],
[0.5 , 0.5 ],
[0.125 , 0.875 ],
[0.3125, 0.6875],
[0.5 , 0.5 ],
[0.125 , 0.875 ],
[0.3125, 0.6875],
[0.4375, 0.5625],
[0.25 , 0.75 ],
[0.4375, 0.5625],
[0.3125, 0.6875],
[0.3125, 0.6875],
[0.375 , 0.625 ],
[0.6875, 0.3125],
[0.375 , 0.625 ],
[0.3125, 0.6875],
[0.4375, 0.5625],
[0.1875, 0.8125],
[0.75 , 0.25 ],
[0.3125, 0.6875],
[0.3125, 0.6875],
[0.3125, 0.6875],
[0.3125, 0.6875],
[0.3125, 0.6875],
[0.3125, 0.6875],
[0.5 , 0.5 ],
[0.625 , 0.375 ],
[0.75 , 0.25 ],
[0.1875, 0.8125],
[0.5625, 0.4375],
[0.4375, 0.5625],
[0.75 , 0.25 ],
[0.625 , 0.375 ],
[0.375 , 0.625 ],
[0.375 , 0.625 ],
[0.75 , 0.25 ],
[0.375 , 0.625 ],
[0.5 , 0.5 ],
[0.375 , 0.625 ],
[0.3125, 0.6875],
[0.4375, 0.5625],
[0.375 , 0.625 ],
[0.375 , 0.625 ],
[0.375 , 0.625 ],
[0.625 , 0.375 ],
[0.625 , 0.375 ],
[0.5 , 0.5 ],
[0.3125, 0.6875],
[0.5625, 0.4375],
[0.3125, 0.6875],
[0.5625, 0.4375],
[0.5625, 0.4375],
[0.5625, 0.4375],
[0.25 , 0.75 ],
[0.75 , 0.25 ],
[0.8125, 0.1875],
[0.4375, 0.5625],
[0.5 , 0.5 ],
[1. , 0. ],
[0.875 , 0.125 ],
[0.625 , 0.375 ],
[0.6875, 0.3125],
[0.6875, 0.3125],
[0.6875, 0.3125],
[0.375 , 0.625 ],
[0.5625, 0.4375],
[0.6875, 0.3125],
[0.4375, 0.5625],
[0.9375, 0.0625],
[0.6875, 0.3125],
[0.6875, 0.3125],
[0.5 , 0.5 ],
[0.4375, 0.5625],
[0.4375, 0.5625],
[0.625 , 0.375 ],
[0.3125, 0.6875],
[0.375 , 0.625 ],
[0.75 , 0.25 ],
[0.75 , 0.25 ],
[0.75 , 0.25 ],
[0.875 , 0.125 ],
[0.875 , 0.125 ],
[0.75 , 0.25 ],
[0.875 , 0.125 ],
[0.25 , 0.75 ],
[0.9375, 0.0625],
[0.6875, 0.3125],
[0.6875, 0.3125],
[0.75 , 0.25 ],
[0.875 , 0.125 ],
[0.5 , 0.5 ],
[0.875 , 0.125 ],
[0.75 , 0.25 ],
[0.75 , 0.25 ],
[0.6875, 0.3125],
[0.625 , 0.375 ],
[0.5625, 0.4375],
[0.9375, 0.0625],
[0.75 , 0.25 ],
[0.6875, 0.3125],
[0.8125, 0.1875],
[0.5625, 0.4375],
[0.5625, 0.4375],
[0.625 , 0.375 ],
[0.6875, 0.3125],
[0.5625, 0.4375],
[0.375 , 0.625 ],
[0.8125, 0.1875],
[0.875 , 0.125 ],
[0.3125, 0.6875],
[0.75 , 0.25 ],
[0.5625, 0.4375],
[0.75 , 0.25 ],
[0.75 , 0.25 ],
[0.6875, 0.3125],
[0.8125, 0.1875],
[0.875 , 0.125 ],
[0.5625, 0.4375],
[0.875 , 0.125 ],
[0.75 , 0.25 ],
[0.375 , 0.625 ],
[0.8125, 0.1875],
[0.6875, 0.3125],
[0.75 , 0.25 ],
[0.3125, 0.6875],
[0.4375, 0.5625],
[0.6875, 0.3125],
[0.375 , 0.625 ],
[0.25 , 0.75 ],
[0.5625, 0.4375],
[0.5625, 0.4375],
[0.25 , 0.75 ],
[0.9375, 0.0625],
[0.8125, 0.1875],
[0.6875, 0.3125],
[0.875 , 0.125 ],
[0.375 , 0.625 ],
[0.6875, 0.3125],
[0.375 , 0.625 ],
[0.5625, 0.4375],
[0.75 , 0.25 ],
[0.5 , 0.5 ],
[0.875 , 0.125 ],
[0.8125, 0.1875],
[0.6875, 0.3125],
[0.5 , 0.5 ],
[0.25 , 0.75 ],
[0.5 , 0.5 ],
[0.75 , 0.25 ],
[0.875 , 0.125 ],
[0.9375, 0.0625],
[0.8125, 0.1875],
[0.8125, 0.1875],
[0.75 , 0.25 ],
[0.3125, 0.6875],
[0.875 , 0.125 ],
[0.875 , 0.125 ],
[0.9375, 0.0625],
[0.75 , 0.25 ],
[0.375 , 0.625 ],
[0.25 , 0.75 ],
[0.8125, 0.1875],
[0.75 , 0.25 ],
[0.875 , 0.125 ],
[0.4375, 0.5625],
[0.875 , 0.125 ],
[0.5625, 0.4375],
[0.625 , 0.375 ],
[0.375 , 0.625 ],
[0.25 , 0.75 ],
[0.5 , 0.5 ],
[0.375 , 0.625 ],
[0.8125, 0.1875],
[0.8125, 0.1875],
[0.875 , 0.125 ],
[0.5625, 0.4375],
[0.375 , 0.625 ],
[0.5625, 0.4375],
[0.5625, 0.4375],
[0.5625, 0.4375],
[0.8125, 0.1875],
[0.5 , 0.5 ],
[0.8125, 0.1875],
[0.6875, 0.3125],
[0.6875, 0.3125],
[0.8125, 0.1875],
[0.75 , 0.25 ],
[1. , 0. ],
[0.875 , 0.125 ],
[0.3125, 0.6875],
[0.875 , 0.125 ],
[0.6875, 0.3125],
[0.4375, 0.5625],
[0.8125, 0.1875],
[0.5 , 0.5 ],
[0.8125, 0.1875],
[0.875 , 0.125 ],
[0.8125, 0.1875],
[0.375 , 0.625 ],
[0.6875, 0.3125],
[0.375 , 0.625 ],
[1. , 0. ],
[0.3125, 0.6875],
[0.5 , 0.5 ],
[0.75 , 0.25 ],
[0.75 , 0.25 ],
[0.75 , 0.25 ],
[0.875 , 0.125 ],
[0.75 , 0.25 ],
[0.8125, 0.1875],
[0.375 , 0.625 ],
[0.75 , 0.25 ],
[0.5625, 0.4375],
[0.4375, 0.5625],
[0.5625, 0.4375],
[0.6875, 0.3125],
[0.75 , 0.25 ],
[0.3125, 0.6875],
[0.3125, 0.6875],
[0.625 , 0.375 ],
[0.4375, 0.5625],
[0.75 , 0.25 ],
[0.8125, 0.1875],
[0.75 , 0.25 ],
[0.875 , 0.125 ],
[0.5625, 0.4375],
[0.5625, 0.4375],
[0.875 , 0.125 ],
[0.5625, 0.4375],
[0.875 , 0.125 ],
[0.8125, 0.1875],
[0.375 , 0.625 ],
[0.1875, 0.8125],
[0.4375, 0.5625],
[1. , 0. ],
[0.875 , 0.125 ],
[0.625 , 0.375 ],
[0.4375, 0.5625],
[0.625 , 0.375 ],
[0.5625, 0.4375],
[0.3125, 0.6875],
[0.8125, 0.1875],
[0.6875, 0.3125],
[0.4375, 0.5625],
[0.875 , 0.125 ],
[0.4375, 0.5625],
[0.6875, 0.3125],
[0.5625, 0.4375],
[0.75 , 0.25 ],
[0.75 , 0.25 ],
[0.5625, 0.4375],
[1. , 0. ],
[0.875 , 0.125 ],
[0.625 , 0.375 ],
[0.375 , 0.625 ],
[0.5625, 0.4375],
[0.75 , 0.25 ],
[0.8125, 0.1875],
[0.5625, 0.4375],
[0.375 , 0.625 ],
[0.5625, 0.4375],
[1. , 0. ],
[0.4375, 0.5625],
[0.5 , 0.5 ],
[0.5 , 0.5 ],
[0.4375, 0.5625],
[0.375 , 0.625 ],
[0.625 , 0.375 ],
[0.4375, 0.5625],
[0.875 , 0.125 ],
[0.8125, 0.1875],
[0.4375, 0.5625],
[0.6875, 0.3125],
[0.875 , 0.125 ],
[0.75 , 0.25 ],
[0.6875, 0.3125],
[0.875 , 0.125 ],
[0.75 , 0.25 ],
[0.8125, 0.1875],
[0.5625, 0.4375],
[0.375 , 0.625 ],
[0.875 , 0.125 ],
[0.8125, 0.1875],
[0.875 , 0.125 ],
[0.625 , 0.375 ],
[0.5625, 0.4375],
[0.875 , 0.125 ],
[0.625 , 0.375 ],
[0.375 , 0.625 ],
[0.375 , 0.625 ],
[0.75 , 0.25 ],
[0.375 , 0.625 ],
[0.5625, 0.4375],
[0.875 , 0.125 ],
[0.6875, 0.3125],
[0.625 , 0.375 ],
[0.8125, 0.1875],
[1. , 0. ],
[0.75 , 0.25 ],
[0.5625, 0.4375],
[0.6875, 0.3125],
[0.4375, 0.5625],
[0.5 , 0.5 ]])
# Here I add the probabilities into df as new columns
df["knc_bad_proba"] = arr[:,0]
df["knc_good_proba"] = arr[:,1]
# Here I make a smaller dataframe containing all the rows where the class probability was 50/50
df2 = df[df["knc_bad_proba"] == 0.5]
df2
| | fixed acidity | volatile acidity | citric acid | residual sugar | chlorides | free sulfur dioxide | total sulfur dioxide | density | pH | sulphates | alcohol | quality | class | pred_knc | knc_bad_proba | knc_good_proba |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
7 | 7.3 | 0.650 | 0.00 | 1.20 | 0.065 | 15.0 | 21.0 | 0.99460 | 3.39 | 0.47 | 10.0 | 7 | good | bad | 0.5 | 0.5 |
259 | 10.0 | 0.310 | 0.47 | 2.60 | 0.085 | 14.0 | 33.0 | 0.99965 | 3.36 | 0.80 | 10.5 | 7 | good | bad | 0.5 | 0.5 |
267 | 7.9 | 0.350 | 0.46 | 3.60 | 0.078 | 15.0 | 37.0 | 0.99730 | 3.35 | 0.86 | 12.8 | 8 | good | bad | 0.5 | 0.5 |
440 | 12.6 | 0.310 | 0.72 | 2.20 | 0.072 | 6.0 | 29.0 | 0.99870 | 2.88 | 0.82 | 9.8 | 8 | good | bad | 0.5 | 0.5 |
501 | 10.4 | 0.440 | 0.73 | 6.55 | 0.074 | 38.0 | 76.0 | 0.99900 | 3.17 | 0.85 | 12.0 | 7 | good | bad | 0.5 | 0.5 |
502 | 10.4 | 0.440 | 0.73 | 6.55 | 0.074 | 38.0 | 76.0 | 0.99900 | 3.17 | 0.85 | 12.0 | 7 | good | bad | 0.5 | 0.5 |
538 | 12.9 | 0.350 | 0.49 | 5.80 | 0.066 | 5.0 | 35.0 | 1.00140 | 3.20 | 0.66 | 12.0 | 7 | good | bad | 0.5 | 0.5 |
586 | 11.1 | 0.310 | 0.49 | 2.70 | 0.094 | 16.0 | 47.0 | 0.99860 | 3.12 | 1.02 | 10.6 | 7 | good | bad | 0.5 | 0.5 |
588 | 5.0 | 0.420 | 0.24 | 2.00 | 0.060 | 19.0 | 50.0 | 0.99170 | 3.72 | 0.74 | 14.0 | 8 | good | bad | 0.5 | 0.5 |
589 | 10.2 | 0.290 | 0.49 | 2.60 | 0.059 | 5.0 | 13.0 | 0.99760 | 3.05 | 0.74 | 10.5 | 7 | good | bad | 0.5 | 0.5 |
648 | 8.7 | 0.480 | 0.30 | 2.80 | 0.066 | 10.0 | 28.0 | 0.99640 | 3.33 | 0.67 | 11.2 | 7 | good | bad | 0.5 | 0.5 |
821 | 4.9 | 0.420 | 0.00 | 2.10 | 0.048 | 16.0 | 42.0 | 0.99154 | 3.71 | 0.74 | 14.0 | 7 | good | bad | 0.5 | 0.5 |
857 | 8.2 | 0.260 | 0.34 | 2.50 | 0.073 | 16.0 | 47.0 | 0.99594 | 3.40 | 0.78 | 11.3 | 7 | good | bad | 0.5 | 0.5 |
1093 | 9.2 | 0.310 | 0.36 | 2.20 | 0.079 | 11.0 | 31.0 | 0.99615 | 3.33 | 0.86 | 12.0 | 7 | good | bad | 0.5 | 0.5 |
1111 | 5.4 | 0.420 | 0.27 | 2.00 | 0.092 | 23.0 | 55.0 | 0.99471 | 3.78 | 0.64 | 12.3 | 7 | good | bad | 0.5 | 0.5 |
1209 | 6.2 | 0.390 | 0.43 | 2.00 | 0.071 | 14.0 | 24.0 | 0.99428 | 3.45 | 0.87 | 11.2 | 7 | good | bad | 0.5 | 0.5 |
1449 | 7.2 | 0.380 | 0.31 | 2.00 | 0.056 | 15.0 | 29.0 | 0.99472 | 3.23 | 0.76 | 11.3 | 8 | good | bad | 0.5 | 0.5 |
1494 | 6.4 | 0.310 | 0.09 | 1.40 | 0.066 | 15.0 | 28.0 | 0.99459 | 3.42 | 0.70 | 10.0 | 7 | good | bad | 0.5 | 0.5 |
29 | 7.8 | 0.645 | 0.00 | 2.00 | 0.082 | 8.0 | 16.0 | 0.99640 | 3.38 | 0.59 | 9.8 | 6 | bad | bad | 0.5 | 0.5 |
1573 | 6.0 | 0.580 | 0.20 | 2.40 | 0.075 | 15.0 | 50.0 | 0.99467 | 3.58 | 0.67 | 12.5 | 6 | bad | bad | 0.5 | 0.5 |
1423 | 6.4 | 0.530 | 0.09 | 3.90 | 0.123 | 14.0 | 31.0 | 0.99680 | 3.50 | 0.67 | 11.0 | 4 | bad | bad | 0.5 | 0.5 |
683 | 8.1 | 0.780 | 0.23 | 2.60 | 0.059 | 5.0 | 15.0 | 0.99700 | 3.37 | 0.56 | 11.3 | 5 | bad | bad | 0.5 | 0.5 |
1266 | 7.2 | 0.570 | 0.05 | 2.30 | 0.081 | 16.0 | 36.0 | 0.99564 | 3.38 | 0.60 | 10.3 | 6 | bad | bad | 0.5 | 0.5 |
1428 | 7.8 | 0.640 | 0.00 | 1.90 | 0.072 | 27.0 | 55.0 | 0.99620 | 3.31 | 0.63 | 11.0 | 5 | bad | bad | 0.5 | 0.5 |
663 | 10.1 | 0.280 | 0.46 | 1.80 | 0.050 | 5.0 | 13.0 | 0.99740 | 3.04 | 0.79 | 10.2 | 6 | bad | bad | 0.5 | 0.5 |
41 | 8.8 | 0.610 | 0.30 | 2.80 | 0.088 | 17.0 | 46.0 | 0.99760 | 3.26 | 0.51 | 9.3 | 4 | bad | bad | 0.5 | 0.5 |
1336 | 6.0 | 0.500 | 0.00 | 1.40 | 0.057 | 15.0 | 26.0 | 0.99448 | 3.36 | 0.45 | 9.5 | 5 | bad | bad | 0.5 | 0.5 |
1028 | 7.2 | 0.340 | 0.21 | 2.50 | 0.075 | 41.0 | 68.0 | 0.99586 | 3.37 | 0.54 | 10.1 | 6 | bad | bad | 0.5 | 0.5 |
1219 | 9.0 | 0.390 | 0.40 | 1.30 | 0.044 | 25.0 | 50.0 | 0.99478 | 3.20 | 0.83 | 10.9 | 6 | bad | bad | 0.5 | 0.5 |
1254 | 7.8 | 0.700 | 0.06 | 1.90 | 0.079 | 20.0 | 35.0 | 0.99628 | 3.40 | 0.69 | 10.9 | 5 | bad | bad | 0.5 | 0.5 |
1274 | 7.8 | 0.580 | 0.13 | 2.10 | 0.102 | 17.0 | 36.0 | 0.99440 | 3.24 | 0.53 | 11.2 | 6 | bad | bad | 0.5 | 0.5 |
This smaller dataframe shows that for 31 rows the model gave a 50/50 chance of being “bad” or “good”. In each of these ties the model predicted “bad”, the first class in knc.classes_. However, 18 of these 31 rows are actually “good”, so those 18 wines were misclassified as “bad”.
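With the default weights="uniform", each row of predict_proba is the share of the 16 nearest neighbors that belong to each class, which is why every probability above is a multiple of 1/16 = 0.0625 and an 8-to-8 split gives exactly 0.5. The sketch below mimics that vote with a hypothetical helper and breaks ties with np.argmax, which returns the first largest entry; this reproduces the all-“bad” predictions in the table, but it is a simplified model of the behavior, not scikit-learn's own code.

```python
import numpy as np

# Same order as knc.classes_, which sorts the labels alphabetically
classes = np.array(["bad", "good"])


def vote_proba(neighbor_labels, classes):
    """Share of the neighbors in each class, like uniform-weight KNN probabilities (hypothetical helper)."""
    neighbor_labels = np.asarray(neighbor_labels)
    return np.array([(neighbor_labels == c).mean() for c in classes])


tie = vote_proba(["bad"] * 8 + ["good"] * 8, classes)       # 8 of 16 neighbors each
lean_bad = vote_proba(["bad"] * 9 + ["good"] * 7, classes)  # 9 of 16 neighbors are "bad"

# np.argmax returns the first index among equal maxima, so a tie resolves to classes[0]
print(tie, classes[np.argmax(tie)])
print(lean_bad, classes[np.argmax(lean_bad)])
```

With an odd number of neighbors, such as 15 or 17, a two-class vote can never split evenly, so these exact ties would not occur.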
To see how many wines of each class the model misclassified, I create a confusion matrix. Based on the ties above, I expect the model to have misclassified more “good” wines than “bad” ones.
c1 = alt.Chart(df).mark_rect().encode(
x= "class",
y= "pred_knc",
color= alt.Color("count()", scale= alt.Scale(scheme= "redblue"))
)
c2 = alt.Chart(df).mark_text(color= "white").encode(
x= "class",
y= "pred_knc",
text= "count()"
)
(c1 + c2).properties(
height= 250,
width= 250
)
Sure enough, the model misclassified more “good” wines. df contains more “bad” rows (250) than “good” rows (217), and I suspect this imbalance, together with ties being broken in favor of “bad”, pushed the model toward predicting “bad”.
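The same counts can also be printed as a table with pandas' crosstab, which is handy for reading off exact numbers. A minimal sketch on made-up labels; the helper names are hypothetical, and on the real data the call would be confusion_table(df["class"], df["pred_knc"]).

```python
import pandas as pd


def confusion_table(actual: pd.Series, predicted: pd.Series) -> pd.DataFrame:
    """Rows are the true classes, columns the predicted classes, cells the counts (hypothetical helper)."""
    return pd.crosstab(actual, predicted, rownames=["class"], colnames=["pred_knc"])


def misclassified_per_class(table: pd.DataFrame) -> pd.Series:
    """Off-diagonal count in each row, i.e. wines of that class that received another label.

    Assumes every true class also appears as a predicted class, so table.loc[c, c] exists.
    """
    correct = pd.Series([table.loc[c, c] for c in table.index], index=table.index)
    return table.sum(axis=1) - correct


# Made-up labels: two of the three "good" wines are predicted "bad"
actual = pd.Series(["good", "good", "bad", "bad", "good"])
predicted = pd.Series(["bad", "good", "bad", "bad", "bad"])
table = confusion_table(actual, predicted)
print(table)
print(misclassified_per_class(table))
```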
Summary
Overall, I used the wine dataset for classification. The models I focused on were DecisionTreeClassifier and KNeighborsClassifier, and DecisionTreeClassifier turned out to be more accurate. The KNeighborsClassifier model was not very accurate: it misclassified many “good” wines as “bad”, which may be related to df containing unequal quantities of good and bad wines.
References
Dataset source:
Wine Dataset: I found this dataset on Kaggle.
Other references I found helpful:
K-Nearest Neighbors: where I learned how to use KNeighborsClassifier.
Worksheet 7: this worksheet helped me plot several charts together.
red_wine_classification: a Kaggle notebook that gave me a few ideas for analyzing the dataframe.