Wine Quality#

Author: Paula Mendez-Lagunas

Course Project, UC Irvine, Math 10, S23

Introduction#

My project focuses on a red-wine dataset whose columns describe the chemical properties of each wine, along with a column assigning each wine a quality score. After refining the original dataframe, I apply classification machine learning models to the data. The two models I chose are DecisionTreeClassifier and KNeighborsClassifier.

Importing#

import pandas as pd
import altair as alt

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

Section 1: Preparing the Data#

In this section I use pandas to gain information about the original dataframe and create a new column named Class indicating whether the wine is good or bad based on its quality score. Finally, I create a new dataframe with a more balanced mix of good and bad wines.

df_temp = pd.read_csv("winequality-red.csv")
# Checking if there are any columns with missing values; notice all columns are numerical
df_temp.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1599 entries, 0 to 1598
Data columns (total 12 columns):
 #   Column                Non-Null Count  Dtype  
---  ------                --------------  -----  
 0   fixed acidity         1599 non-null   float64
 1   volatile acidity      1599 non-null   float64
 2   citric acid           1599 non-null   float64
 3   residual sugar        1599 non-null   float64
 4   chlorides             1599 non-null   float64
 5   free sulfur dioxide   1599 non-null   float64
 6   total sulfur dioxide  1599 non-null   float64
 7   density               1599 non-null   float64
 8   pH                    1599 non-null   float64
 9   sulphates             1599 non-null   float64
 10  alcohol               1599 non-null   float64
 11  quality               1599 non-null   int64  
dtypes: float64(11), int64(1)
memory usage: 150.0 KB

I want to treat this as a classification problem, so I add a new column that indicates whether the wine’s quality is good or bad. I do this using map with a lambda function.

df_temp["quality"].max()
8
df_temp["class"] = df_temp["quality"].map(lambda x: "good" if x>6 else "bad")
# This helps visualize the proportion of good wine to bad wine in the original dataframe
alt.Chart(df_temp).mark_bar().encode(
    x = "class",
    y = "count()",
    tooltip = ["count()"]
)

In order to create a more balanced DataFrame I will take 250 random rows whose class is bad and concatenate them with a dataframe containing all the rows whose class is good.

# I used a random state to get reproducible results
df_good = df_temp[df_temp["class"] == "good"]
df_bad = df_temp[df_temp["class"] == "bad"].sample(250, random_state= 97)
# I use axis = 0 since I want to join them along their rows
df = pd.concat((df_good, df_bad), axis= 0)
df.shape
(467, 13)

Therefore my final dataframe is named df, and we can see that it now contains 467 rows (of which 250 are labeled “bad”) and 13 columns (the 12 original plus the “class” column we added).

Section 2: Visualizing Data Relations#

In this section I mainly use altair to create charts for specific data relations that I want to see. To avoid rewriting the same code, I write a function called ‘make_chart’.

# This helps me understand the distribution of the data in each column
df.describe()
fixed acidity volatile acidity citric acid residual sugar chlorides free sulfur dioxide total sulfur dioxide density pH sulphates alcohol quality
count 467.000000 467.000000 467.000000 467.000000 467.000000 467.000000 467.000000 467.000000 467.000000 467.000000 467.000000 467.000000
mean 8.438972 0.486895 0.297559 2.588544 0.083615 14.706638 41.557816 0.996456 3.307323 0.686574 10.852819 6.179872
std 1.778647 0.177284 0.203943 1.415962 0.042843 9.873047 32.468504 0.002044 0.147025 0.158632 1.171713 0.977088
min 4.600000 0.120000 0.000000 0.900000 0.012000 1.000000 7.000000 0.990070 2.880000 0.370000 8.700000 3.000000
25% 7.200000 0.350000 0.100000 1.900000 0.066000 7.000000 20.000000 0.995160 3.210000 0.570000 9.800000 5.000000
50% 8.100000 0.460000 0.320000 2.200000 0.077000 12.000000 31.000000 0.996430 3.300000 0.670000 10.800000 6.000000
75% 9.550000 0.605000 0.450000 2.600000 0.089000 19.000000 52.000000 0.997700 3.390000 0.770000 11.700000 7.000000
max 15.600000 1.330000 0.790000 15.400000 0.467000 55.000000 289.000000 1.003690 3.900000 1.560000 14.000000 8.000000

From the information above I can see that some columns have larger scales (for example total sulfur dioxide) while some columns have smaller scales (such as density). Also, by comparing the mean, standard deviation, and max value of each column it seems that most columns have outliers.
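Although the analysis here just eyeballs the outliers from describe(), one common heuristic for flagging them is the 1.5×IQR rule. Below is a minimal sketch on a toy column; the rule and the toy data are my own additions, not part of the project:

```python
import pandas as pd

# Toy numeric column; 50.0 is an obvious outlier.
s = pd.Series([1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 50.0])

# Anything beyond 1.5 * IQR from the quartiles is flagged.
q1, q3 = s.quantile(0.25), s.quantile(0.75)
iqr = q3 - q1
outliers = s[(s < q1 - 1.5 * iqr) | (s > q3 + 1.5 * iqr)]
print(outliers.tolist())  # [50.0]
```

Applying the same bounds column by column to df would give a quick count of how many rows each feature flags.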

#The following code is an adaptation of code used in Worksheet 7
def make_chart(col):
    return alt.Chart(df).mark_circle().encode(
        x= alt.X("quality", scale= alt.Scale(zero= False)),
        y= alt.Y(col, scale= alt.Scale(zero= False)),
        color= "class",
    )

Using the function above, I make a chart comparing each input feature column to quality. Since we know from above that the columns have different scales, I include ‘zero= False’ so that each chart gets its own axis scale.

#Only want the first 11 columns because those are the input features; the other 2 (quality, class) are outputs
cols = list(df.columns[:11])
chart_list = [make_chart(col) for col in cols]
# This code was also taken from Worksheet 7
total_chart = alt.vconcat(*chart_list)
total_chart

Based on the charts above, almost every column has outliers, most noticeably residual sugar, chlorides, and total sulfur dioxide. Furthermore, no single input feature shows a clear pattern, so classifying the wine from one column alone would be difficult.

Section 3: Machine Learning Models#

In this final section I apply the DecisionTreeClassifier and the KNeighborsClassifier models and compare their accuracy. I also use train_test_split to test the DecisionTreeClassifier model for overfitting. For the KNeighborsClassifier model I create a confusion matrix to see its prediction results.

# Instantiate; I decided to use 15 max leaf nodes because there are 11 input variables and 15>11
clf = DecisionTreeClassifier(max_leaf_nodes= 15, random_state= 126)
#Fit; I am using all 11 feature columns and want to predict whether the wine is "good" or "bad"
clf.fit(df[cols], df["class"])
DecisionTreeClassifier(max_leaf_nodes=15, random_state=126)
# This reveals that the most influential feature is alcohol
# The 3 columns I said had the most noticeable outliers are at the bottom of this list; I wonder why
pd.Series(clf.feature_importances_ , clf.feature_names_in_).sort_values(ascending= False)
alcohol                 0.507042
volatile acidity        0.171973
sulphates               0.124870
fixed acidity           0.065348
citric acid             0.034490
total sulfur dioxide    0.031346
density                 0.022035
pH                      0.021488
chlorides               0.021406
residual sugar          0.000000
free sulfur dioxide     0.000000
dtype: float64
clf.score(df[cols], df["class"])
0.8758029978586723

The score for this classifier is significantly higher than random guessing would achieve, which makes me wonder whether the model is overfitting the data. To test this classifier for overfitting, I’ll divide the data into a training set and a test set using train_test_split.

# The training set has 60% of the data
X_train, X_test, y_train, y_test = train_test_split(
    df[cols], df["class"], train_size= 0.6, random_state= 0
)
# This time I only want to fit using the X_train and y_train data
clf.fit(X_train, y_train)
DecisionTreeClassifier(max_leaf_nodes=15, random_state=126)
# This describes its accuracy for the training data
clf.score(X_train, y_train)
0.9214285714285714
# This describes its accuracy for the testing data, which it has never seen
clf.score(X_test, y_test)
0.7433155080213903

Comparing the classifier’s scores on the training and testing data, they are fairly close, although the training score is higher. This makes me doubt that the model is badly overfitting, though I’m not certain.
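A single train/test split can be noisy, so one way to get a more reliable read on overfitting is k-fold cross-validation, which averages the test score over several splits. This is a sketch only; it uses synthetic data from make_classification as a stand-in for df[cols], since it is my addition rather than code from the project:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in with the same shape as the wine data: 467 rows, 11 features.
X, y = make_classification(n_samples=467, n_features=11, random_state=0)

clf = DecisionTreeClassifier(max_leaf_nodes=15, random_state=126)

# Five different train/test splits instead of one; a large gap between the
# mean cross-validation score and the full-training-set score would point
# to overfitting.
scores = cross_val_score(clf, X, y, cv=5)
print(scores.mean())
```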

Next I want to try a different machine learning model, K-Nearest Neighbors.

#Instantiate; 16 seems like a good number to try
# I actually initially tried 18 but the score was lower
knc = KNeighborsClassifier(n_neighbors= 16)
#Fit; I use the same input and output features as before
knc.fit(df[cols], df["class"])
KNeighborsClassifier(n_neighbors=16)
knc.score(df[cols], df["class"])
0.721627408993576

This model’s score is still much higher than random guessing, but it is not as good as the DecisionTreeClassifier model.
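One likely reason K-Nearest Neighbors trails the decision tree is scale: Section 2 showed that columns like total sulfur dioxide have much larger ranges than columns like density, and KNN’s distance computation is dominated by the largest-scale columns. Below is a hedged sketch of standardizing the features before the neighbor search; it uses synthetic data in place of df[cols], and an improvement on the real data is plausible but not guaranteed:

```python
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for df[cols] / df["class"].
X, y = make_classification(n_samples=467, n_features=11, random_state=0)

# StandardScaler puts every column on a comparable scale (mean 0, std 1)
# so no single feature dominates the Euclidean distances KNN relies on.
pipe = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=16))
pipe.fit(X, y)
acc = pipe.score(X, y)
print(acc)
```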

Next I want to see which and how many wines were misclassified.

df["pred_knc"] = knc.predict(df[cols])
knc.classes_
array(['bad', 'good'], dtype=object)
# This shows the model's confidence in each prediction; note that several rows are 50/50
arr = knc.predict_proba(df[cols])
arr
array([[0.5   , 0.5   ],
       [0.4375, 0.5625],
       [0.75  , 0.25  ],
       [0.4375, 0.5625],
       [0.875 , 0.125 ],
       [0.5625, 0.4375],
       [0.5625, 0.4375],
       [0.6875, 0.3125],
       [0.25  , 0.75  ],
       [0.25  , 0.75  ],
       [0.375 , 0.625 ],
       [0.8125, 0.1875],
       [0.1875, 0.8125],
       [0.1875, 0.8125],
       [0.5   , 0.5   ],
       [0.25  , 0.75  ],
       [0.5   , 0.5   ],
       [0.25  , 0.75  ],
       [0.625 , 0.375 ],
       [0.375 , 0.625 ],
       [0.625 , 0.375 ],
       [0.5625, 0.4375],
       [0.5625, 0.4375],
       [0.75  , 0.25  ],
       [0.75  , 0.25  ],
       [0.25  , 0.75  ],
       [0.625 , 0.375 ],
       [0.375 , 0.625 ],
       [0.1875, 0.8125],
       [0.5625, 0.4375],
       [0.25  , 0.75  ],
       [0.1875, 0.8125],
       [0.4375, 0.5625],
       [0.4375, 0.5625],
       [0.25  , 0.75  ],
       [0.1875, 0.8125],
       [0.25  , 0.75  ],
       [0.5625, 0.4375],
       [0.625 , 0.375 ],
       [0.5625, 0.4375],
       [0.1875, 0.8125],
       [0.3125, 0.6875],
       [0.625 , 0.375 ],
       [0.375 , 0.625 ],
       [0.4375, 0.5625],
       [0.375 , 0.625 ],
       [0.4375, 0.5625],
       [0.5   , 0.5   ],
       [0.4375, 0.5625],
       [0.25  , 0.75  ],
       [0.8125, 0.1875],
       [0.375 , 0.625 ],
       [0.3125, 0.6875],
       [0.375 , 0.625 ],
       [0.1875, 0.8125],
       [0.5625, 0.4375],
       [0.125 , 0.875 ],
       [0.1875, 0.8125],
       [0.3125, 0.6875],
       [0.3125, 0.6875],
       [0.5   , 0.5   ],
       [0.5   , 0.5   ],
       [0.4375, 0.5625],
       [0.3125, 0.6875],
       [0.0625, 0.9375],
       [0.3125, 0.6875],
       [0.3125, 0.6875],
       [0.1875, 0.8125],
       [0.1875, 0.8125],
       [0.5   , 0.5   ],
       [0.125 , 0.875 ],
       [0.375 , 0.625 ],
       [0.5   , 0.5   ],
       [0.5   , 0.5   ],
       [0.5   , 0.5   ],
       [0.1875, 0.8125],
       [0.875 , 0.125 ],
       [0.4375, 0.5625],
       [0.5   , 0.5   ],
       [0.3125, 0.6875],
       [0.375 , 0.625 ],
       [0.6875, 0.3125],
       [0.125 , 0.875 ],
       [0.25  , 0.75  ],
       [0.125 , 0.875 ],
       [0.5   , 0.5   ],
       [0.3125, 0.6875],
       [0.4375, 0.5625],
       [0.6875, 0.3125],
       [0.6875, 0.3125],
       [0.5625, 0.4375],
       [0.625 , 0.375 ],
       [0.3125, 0.6875],
       [0.5   , 0.5   ],
       [0.625 , 0.375 ],
       [0.25  , 0.75  ],
       [0.1875, 0.8125],
       [0.25  , 0.75  ],
       [0.5625, 0.4375],
       [0.625 , 0.375 ],
       [0.625 , 0.375 ],
       [0.6875, 0.3125],
       [0.6875, 0.3125],
       [0.375 , 0.625 ],
       [0.375 , 0.625 ],
       [0.125 , 0.875 ],
       [0.4375, 0.5625],
       [0.25  , 0.75  ],
       [0.3125, 0.6875],
       [0.125 , 0.875 ],
       [0.0625, 0.9375],
       [0.6875, 0.3125],
       [0.625 , 0.375 ],
       [0.25  , 0.75  ],
       [0.3125, 0.6875],
       [0.0625, 0.9375],
       [0.375 , 0.625 ],
       [0.3125, 0.6875],
       [0.3125, 0.6875],
       [0.3125, 0.6875],
       [0.375 , 0.625 ],
       [0.25  , 0.75  ],
       [0.125 , 0.875 ],
       [0.625 , 0.375 ],
       [0.1875, 0.8125],
       [0.3125, 0.6875],
       [0.1875, 0.8125],
       [0.75  , 0.25  ],
       [0.25  , 0.75  ],
       [0.3125, 0.6875],
       [0.3125, 0.6875],
       [0.3125, 0.6875],
       [0.4375, 0.5625],
       [0.1875, 0.8125],
       [0.3125, 0.6875],
       [0.3125, 0.6875],
       [0.1875, 0.8125],
       [0.1875, 0.8125],
       [0.125 , 0.875 ],
       [0.3125, 0.6875],
       [0.125 , 0.875 ],
       [0.375 , 0.625 ],
       [0.375 , 0.625 ],
       [0.3125, 0.6875],
       [0.375 , 0.625 ],
       [0.4375, 0.5625],
       [0.5625, 0.4375],
       [0.3125, 0.6875],
       [0.625 , 0.375 ],
       [0.3125, 0.6875],
       [0.125 , 0.875 ],
       [0.375 , 0.625 ],
       [0.125 , 0.875 ],
       [0.125 , 0.875 ],
       [0.375 , 0.625 ],
       [0.1875, 0.8125],
       [0.1875, 0.8125],
       [0.625 , 0.375 ],
       [0.375 , 0.625 ],
       [0.8125, 0.1875],
       [0.8125, 0.1875],
       [0.25  , 0.75  ],
       [0.3125, 0.6875],
       [0.3125, 0.6875],
       [0.4375, 0.5625],
       [0.5   , 0.5   ],
       [0.125 , 0.875 ],
       [0.3125, 0.6875],
       [0.5   , 0.5   ],
       [0.125 , 0.875 ],
       [0.3125, 0.6875],
       [0.4375, 0.5625],
       [0.25  , 0.75  ],
       [0.4375, 0.5625],
       [0.3125, 0.6875],
       [0.3125, 0.6875],
       [0.375 , 0.625 ],
       [0.6875, 0.3125],
       [0.375 , 0.625 ],
       [0.3125, 0.6875],
       [0.4375, 0.5625],
       [0.1875, 0.8125],
       [0.75  , 0.25  ],
       [0.3125, 0.6875],
       [0.3125, 0.6875],
       [0.3125, 0.6875],
       [0.3125, 0.6875],
       [0.3125, 0.6875],
       [0.3125, 0.6875],
       [0.5   , 0.5   ],
       [0.625 , 0.375 ],
       [0.75  , 0.25  ],
       [0.1875, 0.8125],
       [0.5625, 0.4375],
       [0.4375, 0.5625],
       [0.75  , 0.25  ],
       [0.625 , 0.375 ],
       [0.375 , 0.625 ],
       [0.375 , 0.625 ],
       [0.75  , 0.25  ],
       [0.375 , 0.625 ],
       [0.5   , 0.5   ],
       [0.375 , 0.625 ],
       [0.3125, 0.6875],
       [0.4375, 0.5625],
       [0.375 , 0.625 ],
       [0.375 , 0.625 ],
       [0.375 , 0.625 ],
       [0.625 , 0.375 ],
       [0.625 , 0.375 ],
       [0.5   , 0.5   ],
       [0.3125, 0.6875],
       [0.5625, 0.4375],
       [0.3125, 0.6875],
       [0.5625, 0.4375],
       [0.5625, 0.4375],
       [0.5625, 0.4375],
       [0.25  , 0.75  ],
       [0.75  , 0.25  ],
       [0.8125, 0.1875],
       [0.4375, 0.5625],
       [0.5   , 0.5   ],
       [1.    , 0.    ],
       [0.875 , 0.125 ],
       [0.625 , 0.375 ],
       [0.6875, 0.3125],
       [0.6875, 0.3125],
       [0.6875, 0.3125],
       [0.375 , 0.625 ],
       [0.5625, 0.4375],
       [0.6875, 0.3125],
       [0.4375, 0.5625],
       [0.9375, 0.0625],
       [0.6875, 0.3125],
       [0.6875, 0.3125],
       [0.5   , 0.5   ],
       [0.4375, 0.5625],
       [0.4375, 0.5625],
       [0.625 , 0.375 ],
       [0.3125, 0.6875],
       [0.375 , 0.625 ],
       [0.75  , 0.25  ],
       [0.75  , 0.25  ],
       [0.75  , 0.25  ],
       [0.875 , 0.125 ],
       [0.875 , 0.125 ],
       [0.75  , 0.25  ],
       [0.875 , 0.125 ],
       [0.25  , 0.75  ],
       [0.9375, 0.0625],
       [0.6875, 0.3125],
       [0.6875, 0.3125],
       [0.75  , 0.25  ],
       [0.875 , 0.125 ],
       [0.5   , 0.5   ],
       [0.875 , 0.125 ],
       [0.75  , 0.25  ],
       [0.75  , 0.25  ],
       [0.6875, 0.3125],
       [0.625 , 0.375 ],
       [0.5625, 0.4375],
       [0.9375, 0.0625],
       [0.75  , 0.25  ],
       [0.6875, 0.3125],
       [0.8125, 0.1875],
       [0.5625, 0.4375],
       [0.5625, 0.4375],
       [0.625 , 0.375 ],
       [0.6875, 0.3125],
       [0.5625, 0.4375],
       [0.375 , 0.625 ],
       [0.8125, 0.1875],
       [0.875 , 0.125 ],
       [0.3125, 0.6875],
       [0.75  , 0.25  ],
       [0.5625, 0.4375],
       [0.75  , 0.25  ],
       [0.75  , 0.25  ],
       [0.6875, 0.3125],
       [0.8125, 0.1875],
       [0.875 , 0.125 ],
       [0.5625, 0.4375],
       [0.875 , 0.125 ],
       [0.75  , 0.25  ],
       [0.375 , 0.625 ],
       [0.8125, 0.1875],
       [0.6875, 0.3125],
       [0.75  , 0.25  ],
       [0.3125, 0.6875],
       [0.4375, 0.5625],
       [0.6875, 0.3125],
       [0.375 , 0.625 ],
       [0.25  , 0.75  ],
       [0.5625, 0.4375],
       [0.5625, 0.4375],
       [0.25  , 0.75  ],
       [0.9375, 0.0625],
       [0.8125, 0.1875],
       [0.6875, 0.3125],
       [0.875 , 0.125 ],
       [0.375 , 0.625 ],
       [0.6875, 0.3125],
       [0.375 , 0.625 ],
       [0.5625, 0.4375],
       [0.75  , 0.25  ],
       [0.5   , 0.5   ],
       [0.875 , 0.125 ],
       [0.8125, 0.1875],
       [0.6875, 0.3125],
       [0.5   , 0.5   ],
       [0.25  , 0.75  ],
       [0.5   , 0.5   ],
       [0.75  , 0.25  ],
       [0.875 , 0.125 ],
       [0.9375, 0.0625],
       [0.8125, 0.1875],
       [0.8125, 0.1875],
       [0.75  , 0.25  ],
       [0.3125, 0.6875],
       [0.875 , 0.125 ],
       [0.875 , 0.125 ],
       [0.9375, 0.0625],
       [0.75  , 0.25  ],
       [0.375 , 0.625 ],
       [0.25  , 0.75  ],
       [0.8125, 0.1875],
       [0.75  , 0.25  ],
       [0.875 , 0.125 ],
       [0.4375, 0.5625],
       [0.875 , 0.125 ],
       [0.5625, 0.4375],
       [0.625 , 0.375 ],
       [0.375 , 0.625 ],
       [0.25  , 0.75  ],
       [0.5   , 0.5   ],
       [0.375 , 0.625 ],
       [0.8125, 0.1875],
       [0.8125, 0.1875],
       [0.875 , 0.125 ],
       [0.5625, 0.4375],
       [0.375 , 0.625 ],
       [0.5625, 0.4375],
       [0.5625, 0.4375],
       [0.5625, 0.4375],
       [0.8125, 0.1875],
       [0.5   , 0.5   ],
       [0.8125, 0.1875],
       [0.6875, 0.3125],
       [0.6875, 0.3125],
       [0.8125, 0.1875],
       [0.75  , 0.25  ],
       [1.    , 0.    ],
       [0.875 , 0.125 ],
       [0.3125, 0.6875],
       [0.875 , 0.125 ],
       [0.6875, 0.3125],
       [0.4375, 0.5625],
       [0.8125, 0.1875],
       [0.5   , 0.5   ],
       [0.8125, 0.1875],
       [0.875 , 0.125 ],
       [0.8125, 0.1875],
       [0.375 , 0.625 ],
       [0.6875, 0.3125],
       [0.375 , 0.625 ],
       [1.    , 0.    ],
       [0.3125, 0.6875],
       [0.5   , 0.5   ],
       [0.75  , 0.25  ],
       [0.75  , 0.25  ],
       [0.75  , 0.25  ],
       [0.875 , 0.125 ],
       [0.75  , 0.25  ],
       [0.8125, 0.1875],
       [0.375 , 0.625 ],
       [0.75  , 0.25  ],
       [0.5625, 0.4375],
       [0.4375, 0.5625],
       [0.5625, 0.4375],
       [0.6875, 0.3125],
       [0.75  , 0.25  ],
       [0.3125, 0.6875],
       [0.3125, 0.6875],
       [0.625 , 0.375 ],
       [0.4375, 0.5625],
       [0.75  , 0.25  ],
       [0.8125, 0.1875],
       [0.75  , 0.25  ],
       [0.875 , 0.125 ],
       [0.5625, 0.4375],
       [0.5625, 0.4375],
       [0.875 , 0.125 ],
       [0.5625, 0.4375],
       [0.875 , 0.125 ],
       [0.8125, 0.1875],
       [0.375 , 0.625 ],
       [0.1875, 0.8125],
       [0.4375, 0.5625],
       [1.    , 0.    ],
       [0.875 , 0.125 ],
       [0.625 , 0.375 ],
       [0.4375, 0.5625],
       [0.625 , 0.375 ],
       [0.5625, 0.4375],
       [0.3125, 0.6875],
       [0.8125, 0.1875],
       [0.6875, 0.3125],
       [0.4375, 0.5625],
       [0.875 , 0.125 ],
       [0.4375, 0.5625],
       [0.6875, 0.3125],
       [0.5625, 0.4375],
       [0.75  , 0.25  ],
       [0.75  , 0.25  ],
       [0.5625, 0.4375],
       [1.    , 0.    ],
       [0.875 , 0.125 ],
       [0.625 , 0.375 ],
       [0.375 , 0.625 ],
       [0.5625, 0.4375],
       [0.75  , 0.25  ],
       [0.8125, 0.1875],
       [0.5625, 0.4375],
       [0.375 , 0.625 ],
       [0.5625, 0.4375],
       [1.    , 0.    ],
       [0.4375, 0.5625],
       [0.5   , 0.5   ],
       [0.5   , 0.5   ],
       [0.4375, 0.5625],
       [0.375 , 0.625 ],
       [0.625 , 0.375 ],
       [0.4375, 0.5625],
       [0.875 , 0.125 ],
       [0.8125, 0.1875],
       [0.4375, 0.5625],
       [0.6875, 0.3125],
       [0.875 , 0.125 ],
       [0.75  , 0.25  ],
       [0.6875, 0.3125],
       [0.875 , 0.125 ],
       [0.75  , 0.25  ],
       [0.8125, 0.1875],
       [0.5625, 0.4375],
       [0.375 , 0.625 ],
       [0.875 , 0.125 ],
       [0.8125, 0.1875],
       [0.875 , 0.125 ],
       [0.625 , 0.375 ],
       [0.5625, 0.4375],
       [0.875 , 0.125 ],
       [0.625 , 0.375 ],
       [0.375 , 0.625 ],
       [0.375 , 0.625 ],
       [0.75  , 0.25  ],
       [0.375 , 0.625 ],
       [0.5625, 0.4375],
       [0.875 , 0.125 ],
       [0.6875, 0.3125],
       [0.625 , 0.375 ],
       [0.8125, 0.1875],
       [1.    , 0.    ],
       [0.75  , 0.25  ],
       [0.5625, 0.4375],
       [0.6875, 0.3125],
       [0.4375, 0.5625],
       [0.5   , 0.5   ]])
# Here I add the probabilities into df as new columns
df["knc_bad_proba"] = arr[:,0]
df["knc_good_proba"] = arr[:,1]
# Here I make a smaller dataframe containing all the rows where the class probability was 50/50
df2 = df[df["knc_bad_proba"] == 0.5]
df2
fixed acidity volatile acidity citric acid residual sugar chlorides free sulfur dioxide total sulfur dioxide density pH sulphates alcohol quality class pred_knc knc_bad_proba knc_good_proba
7 7.3 0.650 0.00 1.20 0.065 15.0 21.0 0.99460 3.39 0.47 10.0 7 good bad 0.5 0.5
259 10.0 0.310 0.47 2.60 0.085 14.0 33.0 0.99965 3.36 0.80 10.5 7 good bad 0.5 0.5
267 7.9 0.350 0.46 3.60 0.078 15.0 37.0 0.99730 3.35 0.86 12.8 8 good bad 0.5 0.5
440 12.6 0.310 0.72 2.20 0.072 6.0 29.0 0.99870 2.88 0.82 9.8 8 good bad 0.5 0.5
501 10.4 0.440 0.73 6.55 0.074 38.0 76.0 0.99900 3.17 0.85 12.0 7 good bad 0.5 0.5
502 10.4 0.440 0.73 6.55 0.074 38.0 76.0 0.99900 3.17 0.85 12.0 7 good bad 0.5 0.5
538 12.9 0.350 0.49 5.80 0.066 5.0 35.0 1.00140 3.20 0.66 12.0 7 good bad 0.5 0.5
586 11.1 0.310 0.49 2.70 0.094 16.0 47.0 0.99860 3.12 1.02 10.6 7 good bad 0.5 0.5
588 5.0 0.420 0.24 2.00 0.060 19.0 50.0 0.99170 3.72 0.74 14.0 8 good bad 0.5 0.5
589 10.2 0.290 0.49 2.60 0.059 5.0 13.0 0.99760 3.05 0.74 10.5 7 good bad 0.5 0.5
648 8.7 0.480 0.30 2.80 0.066 10.0 28.0 0.99640 3.33 0.67 11.2 7 good bad 0.5 0.5
821 4.9 0.420 0.00 2.10 0.048 16.0 42.0 0.99154 3.71 0.74 14.0 7 good bad 0.5 0.5
857 8.2 0.260 0.34 2.50 0.073 16.0 47.0 0.99594 3.40 0.78 11.3 7 good bad 0.5 0.5
1093 9.2 0.310 0.36 2.20 0.079 11.0 31.0 0.99615 3.33 0.86 12.0 7 good bad 0.5 0.5
1111 5.4 0.420 0.27 2.00 0.092 23.0 55.0 0.99471 3.78 0.64 12.3 7 good bad 0.5 0.5
1209 6.2 0.390 0.43 2.00 0.071 14.0 24.0 0.99428 3.45 0.87 11.2 7 good bad 0.5 0.5
1449 7.2 0.380 0.31 2.00 0.056 15.0 29.0 0.99472 3.23 0.76 11.3 8 good bad 0.5 0.5
1494 6.4 0.310 0.09 1.40 0.066 15.0 28.0 0.99459 3.42 0.70 10.0 7 good bad 0.5 0.5
29 7.8 0.645 0.00 2.00 0.082 8.0 16.0 0.99640 3.38 0.59 9.8 6 bad bad 0.5 0.5
1573 6.0 0.580 0.20 2.40 0.075 15.0 50.0 0.99467 3.58 0.67 12.5 6 bad bad 0.5 0.5
1423 6.4 0.530 0.09 3.90 0.123 14.0 31.0 0.99680 3.50 0.67 11.0 4 bad bad 0.5 0.5
683 8.1 0.780 0.23 2.60 0.059 5.0 15.0 0.99700 3.37 0.56 11.3 5 bad bad 0.5 0.5
1266 7.2 0.570 0.05 2.30 0.081 16.0 36.0 0.99564 3.38 0.60 10.3 6 bad bad 0.5 0.5
1428 7.8 0.640 0.00 1.90 0.072 27.0 55.0 0.99620 3.31 0.63 11.0 5 bad bad 0.5 0.5
663 10.1 0.280 0.46 1.80 0.050 5.0 13.0 0.99740 3.04 0.79 10.2 6 bad bad 0.5 0.5
41 8.8 0.610 0.30 2.80 0.088 17.0 46.0 0.99760 3.26 0.51 9.3 4 bad bad 0.5 0.5
1336 6.0 0.500 0.00 1.40 0.057 15.0 26.0 0.99448 3.36 0.45 9.5 5 bad bad 0.5 0.5
1028 7.2 0.340 0.21 2.50 0.075 41.0 68.0 0.99586 3.37 0.54 10.1 6 bad bad 0.5 0.5
1219 9.0 0.390 0.40 1.30 0.044 25.0 50.0 0.99478 3.20 0.83 10.9 6 bad bad 0.5 0.5
1254 7.8 0.700 0.06 1.90 0.079 20.0 35.0 0.99628 3.40 0.69 10.9 5 bad bad 0.5 0.5
1274 7.8 0.580 0.13 2.10 0.102 17.0 36.0 0.99440 3.24 0.53 11.2 6 bad bad 0.5 0.5

From this smaller dataframe we see that 31 rows received a 50/50 probability from this model of being “bad” or “good”. Since “bad” is the first class defined in the classifier, it predicted “bad” for each of these wines. However, 18 of these 31 rows actually have class “good”, so those 18 wines were misclassified as “bad”.
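This tie-breaking behavior can be checked on a minimal synthetic example of my own (not from the project): with an even n_neighbors and one vote for each class, predict_proba is 50/50 and predict falls back to the first class in classes_, here “bad”:

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Two training points, one per class, equidistant from the query point 1.0.
X = np.array([[0.0], [2.0]])
y = np.array(["bad", "good"])

knc = KNeighborsClassifier(n_neighbors=2)
knc.fit(X, y)

proba = knc.predict_proba([[1.0]])  # one vote each -> [[0.5, 0.5]]
pred = knc.predict([[1.0]])         # tie resolves to the first class, "bad"
print(proba, pred)
```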

In order to see how many of each class the model misclassified I want to create a confusion matrix. Based on the above, it seems that the model should have misclassified more “good” wines.

c1 = alt.Chart(df).mark_rect().encode(
    x= "class",
    y= "pred_knc",
    color= alt.Color("count()", scale= alt.Scale(scheme= "redblue"))
)

c2 = alt.Chart(df).mark_text(color= "white").encode(
    x= "class",
    y= "pred_knc",
    text= "count()"
)

(c1 + c2).properties(
    height= 250,
    width= 250
)

Sure enough, the model misclassified more “good” wines. Since df contains more rows whose class is “bad” (250) than rows whose class is “good” (217), that imbalance may have influenced the model to classify more wines as “bad”.
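The counts shown in the chart can also be computed without altair using scikit-learn’s confusion_matrix. A sketch on toy labels (stand-ins for df["class"] and df["pred_knc"]):

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy true and predicted labels; the first "good" wine is misclassified.
y_true = np.array(["good", "good", "bad", "bad", "good"])
y_pred = np.array(["bad", "good", "bad", "bad", "good"])

# Rows are true classes, columns are predictions, in the order given by
# labels (matching knc.classes_ = ["bad", "good"]).
cm = confusion_matrix(y_true, y_pred, labels=["bad", "good"])
print(cm)  # [[2 0]
           #  [1 2]]
```

The off-diagonal entries are the misclassifications, so this gives the same four numbers as the mark_rect chart.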

Summary#

Overall, I used the wine dataset for classification. The models I focused on were DecisionTreeClassifier and KNeighborsClassifier, and DecisionTreeClassifier turned out to be more accurate. Based on the predictions the KNeighborsClassifier model made, and the fact that df had unequal quantities of good and bad wine, it was not a very accurate model.

References#


  • What is the source of your dataset(s)?

Wine Dataset I found this dataset on Kaggle

  • List any other references that you found helpful.

K-Nearest Neighbors Here is where I learned how to use KNeighborsClassifier

Worksheet 7 This worksheet helped me with plotting various charts all together

red_wine_classification This is a notebook from Kaggle that gave me a few ideas for analyzing the dataframe

Submission#

Using the Share button at the top right, enable Comment privileges for anyone with a link to the project. Then submit that link on Canvas.

Created in Deepnote