Visualization in Python

Quick introduction to Matplotlib

Matplotlib is the most famous visualization library in Python. It has features that are very similar to the plotting features in Matlab. We actually won’t use Matplotlib much in Math 10 (we will use Altair more), so this will just be a short introduction.

There are a few different ways to use Matplotlib (which I think is part of what made Matplotlib difficult for me to learn). The way we will present it in Math 10, the object-oriented interface, involves a little more writing but offers more flexibility and is closer to typical Python code. Here is a relevant quote from the Matplotlib documentation:

We call methods that do the plotting directly from the Axes, which gives us much more flexibility and power in customizing our plot… In general, try to use the object-oriented interface over the pyplot interface.

x = [70,95,60,110,80]
y = [200,150,120,190,120]
import matplotlib.pyplot as plt

For now, don’t worry too much about what the line fig, ax = plt.subplots() is doing. You can think of plt.subplots() as returning two objects, one which we name fig and one which we name ax. The fig object is the Figure, the whole image that gets displayed, and the ax object is an Axes, the region inside the figure where we do the plotting. If you ask for a grid of plots, for example plt.subplots(1, 2), then ax is instead a NumPy array of Axes objects.

The line ax.plot(x,y) behaves very similarly to Matlab. We think of x as holding the x-coordinates and y as holding the y-coordinates. Matplotlib then connects the dots with straight lines, just like in Matlab.
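The version where ax is an array of Axes objects can be sketched as follows. This is not from the original notebook; it asks plt.subplots for one row of two plots. The call matplotlib.use("Agg") is an assumption added only so the code runs outside Jupyter; in a notebook you can leave it out.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this also runs outside Jupyter
import matplotlib.pyplot as plt

x = [70, 95, 60, 110, 80]
y = [200, 150, 120, 190, 120]

# Asking for 1 row and 2 columns returns one Figure and a NumPy array of two Axes
fig, axs = plt.subplots(1, 2)
axs[0].plot(x, y)     # line plot in the left Axes
axs[1].scatter(x, y)  # scatter plot in the right Axes

print(type(axs).__name__, axs.shape)  # ndarray (2,)
print(axs[0].figure is fig)           # True: each Axes belongs to the Figure
```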

fig, ax = plt.subplots()
ax.plot(x,y)
[<matplotlib.lines.Line2D at 0x7fb0ea103bd0>]
[Figure: line plot connecting the points (x, y) in list order]
type(ax)
matplotlib.axes._subplots.AxesSubplot

Making a scatter plot instead of a line plot is very similar: we replace ax.plot with ax.scatter.

fig, ax = plt.subplots()
ax.scatter(x,y)
<matplotlib.collections.PathCollection at 0x7fb0ea11dc90>
[Figure: scatter plot of the five (x, y) points]

The fig object holds the whole image, so evaluating fig by itself displays the figure again.

fig
[Figure: the same scatter plot, displayed again]

Using NumPy and Matplotlib to plot cos(x)

Let’s plot y = cos(x) for x between 0 and 2π. Our first attempt contains several mistakes.

x1 = range(0,2*pi)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
/var/folders/8j/gshrlmtn7dg4qtztj4d4t_w40000gn/T/ipykernel_14988/1058624340.py in <module>
----> 1 x1 = range(0,2*pi)

NameError: name 'pi' is not defined
import numpy as np
np.pi
3.141592653589793

Remember that range can only take integer arguments.

x1 = range(0,2*np.pi)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/var/folders/8j/gshrlmtn7dg4qtztj4d4t_w40000gn/T/ipykernel_14988/86406984.py in <module>
----> 1 x1 = range(0,2*np.pi)

TypeError: 'float' object cannot be interpreted as an integer

NumPy provides a function very similar to range. The NumPy version is called arange, and it accepts float arguments.

x1 = np.arange(0,2*np.pi)

Python’s range makes objects of type range. NumPy’s arange makes objects of type numpy.ndarray, which is the most common data type in NumPy.

type(x1)
numpy.ndarray
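A small side-by-side sketch of the difference between range and arange described above (not from the original notebook; the step 0.25 is chosen only for illustration):

```python
import numpy as np

# Built-in range: integers only, default step 1, stop value excluded
print(list(range(0, 7)))  # [0, 1, 2, 3, 4, 5, 6]

# range rejects a float step (or float endpoints)
try:
    range(0, 1, 0.25)
except TypeError as err:
    print("TypeError:", err)

# np.arange accepts floats and returns a NumPy array instead of a range object
a = np.arange(0, 1, 0.25)
print(a.tolist())  # [0.0, 0.25, 0.5, 0.75]
print(type(a))     # <class 'numpy.ndarray'>
```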

Now let’s try to make the y-coordinates.

y1 = cos(x1)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
/var/folders/8j/gshrlmtn7dg4qtztj4d4t_w40000gn/T/ipykernel_14988/1263318031.py in <module>
----> 1 y1 = cos(x1)

NameError: name 'cos' is not defined
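Python has no built-in cos; cosine lives in modules such as math and NumPy. A minimal sketch (not from the original notebook) of why the NumPy version is the right choice here: math.cos handles one number at a time, while np.cos is applied to every entry of an array.

```python
import math
import numpy as np

x1 = np.arange(0, 2 * np.pi)  # the 7 values 0., 1., ..., 6.

# math.cos accepts a single number
print(math.cos(0.0))  # 1.0

# Passing a whole array to math.cos fails with a TypeError
try:
    math.cos(x1)
except TypeError as err:
    print("TypeError:", err)

# np.cos works elementwise and returns an array of the same shape
y1 = np.cos(x1)
print(y1.shape)  # (7,)
```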
y1 = np.cos(x1)
fig, ax = plt.subplots()
ax.plot(x1,y1)
[<matplotlib.lines.Line2D at 0x7fb0d8f3e490>]
[Figure: cos(x) drawn through only 7 points, so the curve looks jagged]

It doesn’t look very good because the step size is too big. By default, arange (and range) uses a step size of 1.

x1
array([0., 1., 2., 3., 4., 5., 6.])
x1.shape
(7,)

Here we use a step size of 0.01.

x2 = np.arange(0,2*np.pi,0.01)
y2 = np.cos(x2)
fig, ax = plt.subplots()
ax.plot(x2,y2)
[<matplotlib.lines.Line2D at 0x7fb0d9028390>]
[Figure: smooth graph of cos(x) for x from 0 to 2π]
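To see why the smaller step gives a smoother curve, here is a quick count of the points each version produces (a sketch, not from the original notebook). np.arange produces about (stop − start)/step values, rounded up, and never includes stop itself.

```python
import numpy as np

x1 = np.arange(0, 2 * np.pi)        # default step 1
x2 = np.arange(0, 2 * np.pi, 0.01)  # step 0.01

print(x1.shape)  # (7,)
print(x2.shape)  # (629,)

# The count is (stop - start) / step, rounded up
print(int(np.ceil(2 * np.pi / 0.01)))  # 629

# The endpoint 2*pi itself is not included
print(x2[-1] < 2 * np.pi)  # True
```

The NumPy documentation recommends np.linspace when the step is not an integer, since floating-point rounding can make the length of an arange result hard to predict.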

Plotting based on the Grammar of Graphics

Here we will introduce three more plotting libraries: Altair (the most important for us in Math 10), Seaborn, and Plotly. These three libraries are very similar to each other (not so similar to Matplotlib, although Seaborn is built on top of Matplotlib), and I believe all three are based on a notion called the Grammar of Graphics. (The book The Grammar of Graphics is freely available to download on campus or through the VPN.)

Here is the basic setup:

  • We have a pandas DataFrame, and each row in the DataFrame corresponds to one observation (or one data point).

  • Columns in the DataFrame correspond to different variables.

  • To produce the visualizations, we encode different columns from the DataFrame into visual properties of the chart.

import pandas as pd

Here we make a pandas DataFrame using the same x and y lists from above.

x
[70, 95, 60, 110, 80]
y
[200, 150, 120, 190, 120]
df = pd.DataFrame({"height":x, "weight":y})
df
   height  weight
0      70     200
1      95     150
2      60     120
3     110     190
4      80     120
df.columns
Index(['height', 'weight'], dtype='object')

Let’s first plot this data using Seaborn, which I think is the most famous member of this family of plotting libraries.

import seaborn as sns

Notice that we specify the pandas DataFrame holding the data, as well as which column to use for the x-coordinates and which column to use for the y-coordinates. By default, Seaborn sorts the data by x-coordinate before drawing the line (that’s why the line does not cross itself here).

sns.lineplot(
    data=df,
    x="height",
    y="weight"
)
<AxesSubplot:xlabel='height', ylabel='weight'>
[Figure: Seaborn line plot of weight against height]

The data=df above is what’s called a keyword argument: data is the name of a parameter of sns.lineplot, and df is the value we pass for it. That assignment only happens locally, inside the function, so it does not create a variable named data in our own session.

data
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
/var/folders/8j/gshrlmtn7dg4qtztj4d4t_w40000gn/T/ipykernel_14988/391604064.py in <module>
----> 1 data

NameError: name 'data' is not defined
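The same scoping behavior can be seen with any Python function. In this sketch, describe is a hypothetical stand-in for a plotting function like sns.lineplot; it is not part of any library.

```python
def describe(data=None, x=None, y=None):
    """Hypothetical stand-in for a plotting function such as sns.lineplot."""
    # Inside the function, `data` is a local variable bound to the argument
    return f"plotting {y} against {x} using {len(data[x])} rows"

table = {"height": [70, 95, 60, 110, 80], "weight": [200, 150, 120, 190, 120]}
print(describe(data=table, x="height", y="weight"))
# plotting weight against height using 5 rows

# After the call, no variable named `data` exists in our own namespace
print("data" in globals())  # False
```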

Here is the Plotly version of the same plot.

import plotly.express as px

The syntax is very similar. Plotly uses the keyword argument data_frame instead of data, and it does not sort the x-values; it connects the points in the order the rows appear in df. Otherwise, the call is the same.

px.line(
    data_frame=df,
    x="height",
    y="weight"
)
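Because Plotly connects the points in row order, one way to get the same left-to-right line that Seaborn draws is to sort the DataFrame by the x column first. A minimal pandas sketch (df_sorted is a name introduced here for illustration):

```python
import pandas as pd

df = pd.DataFrame({"height": [70, 95, 60, 110, 80],
                   "weight": [200, 150, 120, 190, 120]})

# Row order as stored: a line through these points zigzags left and right
print(df["height"].tolist())  # [70, 95, 60, 110, 80]

# Sorting by the x column gives left-to-right order, like Seaborn's default
df_sorted = df.sort_values("height")
print(df_sorted["height"].tolist())  # [60, 70, 80, 95, 110]
print(df_sorted["weight"].tolist())  # [120, 200, 120, 150, 190]
```

Passing data_frame=df_sorted to px.line should then draw the line from left to right.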

Here is the Altair version of the same plot. Altair is the plotting library we will use most often in Math 10. The syntax is a little different from the Seaborn and Plotly syntax. It is worth practicing with this Altair syntax until you get comfortable with it.

import altair as alt
alt.Chart(data=df).mark_line().encode(
    x="height",
    y="weight"
)

If we want to change from a line plot to a scatter plot, in Altair, we change mark_line to mark_circle.

alt.Chart(data=df).mark_circle().encode(
    x="height",
    y="weight"
)

Let’s see how to change the domain shown on the x-axis. As an intermediate step, we give the longer version of the same code as above. We change x="height" to x=alt.X("height"). So far, this does not change the chart.

alt.Chart(data=df).mark_circle().encode(
    x=alt.X("height"),
    y="weight"
)

Now we add a new keyword argument inside alt.X which specifies that the x-values should be displayed between -100 and 500.

alt.Chart(data=df).mark_circle().encode(
    x=alt.X("height", scale=alt.Scale(domain=(-100,500))),
    y="weight"
)

If instead you want Altair to choose the range for you, but without forcing 0 to be included, you can specify zero=False. Altair then picks a range that fits closely around the data. (For now this only changes the x-axis, not the y-axis, because we are only using alt.X, not alt.Y.)

alt.Chart(data=df).mark_circle().encode(
    x=alt.X("height", scale=alt.Scale(zero=False)),
    y="weight"
)

Let’s recall what data is in df.

df
   height  weight
0      70     200
1      95     150
2      60     120
3     110     190
4      80     120

Here is one way to create a new column in a DataFrame. The most important thing is that the number of values we provide equals the number of rows, which is why we give 5 state names here.

df["state"] = ["California","California","Oregon","Nevada","Arizona"]
df
   height  weight       state
0      70     200  California
1      95     150  California
2      60     120      Oregon
3     110     190      Nevada
4      80     120     Arizona
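Here is what happens when the number of values does not match the number of rows (a sketch, not from the original notebook). The column name "country" is hypothetical, used only to show that a single scalar value is the one exception to the length rule: pandas repeats it down the whole column.

```python
import pandas as pd

df = pd.DataFrame({"height": [70, 95, 60, 110, 80],
                   "weight": [200, 150, 120, 190, 120]})

# Five values for five rows: works
df["state"] = ["California", "California", "Oregon", "Nevada", "Arizona"]

# Four values for five rows: pandas refuses with a ValueError
try:
    df["GPA"] = [1.5, 2.4, 4, 3.5]
except ValueError as err:
    print("ValueError:", err)

# A single scalar is repeated for every row (hypothetical column name)
df["country"] = "USA"
print(df.shape)  # (5, 4)
```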

Let’s add one more column, containing floats.

df["GPA"] = [1.5, 2.4, 4, 3.5, 4.7]
df
   height  weight       state  GPA
0      70     200  California  1.5
1      95     150  California  2.4
2      60     120      Oregon  4.0
3     110     190      Nevada  3.5
4      80     120     Arizona  4.7

Here we encode the “state” column into the color of the chart, and we encode the “GPA” column into the size of the points.

alt.Chart(data=df).mark_circle().encode(
    x=alt.X("height", scale=alt.Scale(zero=False)),
    y="weight",
    color="state",
    size="GPA"
)

Everything works very similarly in Seaborn and in Plotly. Here is the Seaborn version; note that Seaborn calls the color encoding hue rather than color.

sns.scatterplot(
    data=df,
    x="height",
    y="weight",
    hue="state",
    size="GPA"
)
<AxesSubplot:xlabel='height', ylabel='weight'>
[Figure: Seaborn scatter plot of weight against height, colored by state and sized by GPA]

Here is the Plotly version.

px.scatter(
    data_frame=df,
    x="height",
    y="weight",
    color="state",
    size="GPA"
)
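For contrast, here is a rough sketch of doing the same color and size encoding by hand in Matplotlib, which shows the work the three grammar-of-graphics libraries do for us. The mapping from states to Matplotlib's default color cycle and the size factor 40 are arbitrary choices made for this illustration; the Agg backend line is only there so the code runs outside Jupyter.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs outside Jupyter
import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame({
    "height": [70, 95, 60, 110, 80],
    "weight": [200, 150, 120, 190, 120],
    "state": ["California", "California", "Oregon", "Nevada", "Arizona"],
    "GPA": [1.5, 2.4, 4, 3.5, 4.7],
})

# Color encoding: give each distinct state its own color from the default cycle
palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]
states = df["state"].unique()
state_to_color = {s: palette[i % len(palette)] for i, s in enumerate(states)}
colors = df["state"].map(state_to_color).tolist()

# Size encoding: marker area proportional to GPA (factor 40 chosen arbitrarily)
sizes = (40 * df["GPA"]).tolist()

fig, ax = plt.subplots()
ax.scatter(df["height"], df["weight"], c=colors, s=sizes)
ax.set_xlabel("height")
ax.set_ylabel("weight")
# Unlike Altair, Seaborn, and Plotly, nothing here builds a legend for us
```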

Let’s access the row at index 2 of the DataFrame. Using df[2] doesn’t work: with square brackets, pandas looks for a column named 2, and since there is no such column, we get a KeyError.

df[2]
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/miniconda3/envs/math10s22/lib/python3.7/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
   3360             try:
-> 3361                 return self._engine.get_loc(casted_key)
   3362             except KeyError as err:

~/miniconda3/envs/math10s22/lib/python3.7/site-packages/pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

~/miniconda3/envs/math10s22/lib/python3.7/site-packages/pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 2

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
/var/folders/8j/gshrlmtn7dg4qtztj4d4t_w40000gn/T/ipykernel_14988/2772902488.py in <module>
----> 1 df[2]

~/miniconda3/envs/math10s22/lib/python3.7/site-packages/pandas/core/frame.py in __getitem__(self, key)
   3456             if self.columns.nlevels > 1:
   3457                 return self._getitem_multilevel(key)
-> 3458             indexer = self.columns.get_loc(key)
   3459             if is_integer(indexer):
   3460                 indexer = [indexer]

~/miniconda3/envs/math10s22/lib/python3.7/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
   3361                 return self._engine.get_loc(casted_key)
   3362             except KeyError as err:
-> 3363                 raise KeyError(key) from err
   3364 
   3365         if is_scalar(key) and isna(key) and not self.hasnans:

KeyError: 2

Here is the right way to find the row at index 2: df.iloc[2] selects a row by its integer position.

We think of each row as corresponding to one observation. In this case, each row could correspond to one person. It’s worth looking at the data in this row, and looking back at the three plots from Altair, Seaborn, and Plotly, and seeing how this data is encoded in one of the 5 data points from these charts. The point is at coordinates (60,120), it is colored according to the state being “Oregon”, and the size corresponds to the GPA value.

df.iloc[2]
height        60
weight       120
state     Oregon
GPA          4.0
Name: 2, dtype: object
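In our df, the index labels 0 through 4 coincide with the integer positions, so df.iloc[2] and the label-based df.loc[2] return the same row. The sketch below (not from the original notebook) sorts the DataFrame so that labels and positions no longer match, which shows how the two differ.

```python
import pandas as pd

df = pd.DataFrame({"height": [70, 95, 60, 110, 80],
                   "weight": [200, 150, 120, 190, 120]})

# Square brackets with a column name select a column, not a row
print(df["height"].tolist())  # [70, 95, 60, 110, 80]

# Sorting keeps the original index labels, so labels and positions no longer match
df_sorted = df.sort_values("height")
print(df_sorted.index.tolist())  # [2, 0, 4, 1, 3]

print(df_sorted.iloc[2]["height"])  # 80: the row at integer position 2
print(df_sorted.loc[2]["height"])   # 60: the row whose index label is 2
```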

If you try to use px.scatter(data=...), you will get a TypeError, because scatter has no parameter named data. It needs to be px.scatter(data_frame=...). As you can see at the beginning of the documentation, data_frame is the first keyword argument of the scatter function.

help(px.scatter)
Help on function scatter in module plotly.express._chart_types:

scatter(data_frame=None, x=None, y=None, color=None, symbol=None, size=None, hover_name=None, hover_data=None, custom_data=None, text=None, facet_row=None, facet_col=None, facet_col_wrap=0, facet_row_spacing=None, facet_col_spacing=None, error_x=None, error_x_minus=None, error_y=None, error_y_minus=None, animation_frame=None, animation_group=None, category_orders=None, labels=None, orientation=None, color_discrete_sequence=None, color_discrete_map=None, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, symbol_sequence=None, symbol_map=None, opacity=None, size_max=None, marginal_x=None, marginal_y=None, trendline=None, trendline_options=None, trendline_color_override=None, trendline_scope='trace', log_x=False, log_y=False, range_x=None, range_y=None, render_mode='auto', title=None, template=None, width=None, height=None)
        In a scatter plot, each row of `data_frame` is represented by a symbol
        mark in 2D space.
        
    Parameters
    ----------
    data_frame: DataFrame or array-like or dict
        This argument needs to be passed for column names (and not keyword
        names) to be used. Array-like and dict are tranformed internally to a
        pandas DataFrame. Optional: if missing, a DataFrame gets constructed
        under the hood using the other arguments.
    x: str or int or Series or array-like
        Either a name of a column in `data_frame`, or a pandas Series or
        array_like object. Values from this column or array_like are used to
        position marks along the x axis in cartesian coordinates. Either `x` or
        `y` can optionally be a list of column references or array_likes,  in
        which case the data will be treated as if it were 'wide' rather than
        'long'.
    y: str or int or Series or array-like
        Either a name of a column in `data_frame`, or a pandas Series or
        array_like object. Values from this column or array_like are used to
        position marks along the y axis in cartesian coordinates. Either `x` or
        `y` can optionally be a list of column references or array_likes,  in
        which case the data will be treated as if it were 'wide' rather than
        'long'.
    color: str or int or Series or array-like
        Either a name of a column in `data_frame`, or a pandas Series or
        array_like object. Values from this column or array_like are used to
        assign color to marks.
    symbol: str or int or Series or array-like
        Either a name of a column in `data_frame`, or a pandas Series or
        array_like object. Values from this column or array_like are used to
        assign symbols to marks.
    size: str or int or Series or array-like
        Either a name of a column in `data_frame`, or a pandas Series or
        array_like object. Values from this column or array_like are used to
        assign mark sizes.
    hover_name: str or int or Series or array-like
        Either a name of a column in `data_frame`, or a pandas Series or
        array_like object. Values from this column or array_like appear in bold
        in the hover tooltip.
    hover_data: list of str or int, or Series or array-like, or dict
        Either a list of names of columns in `data_frame`, or pandas Series, or
        array_like objects or a dict with column names as keys, with values
        True (for default formatting) False (in order to remove this column
        from hover information), or a formatting string, for example ':.3f' or
        '|%a' or list-like data to appear in the hover tooltip or tuples with a
        bool or formatting string as first element, and list-like data to
        appear in hover as second element Values from these columns appear as
        extra data in the hover tooltip.
    custom_data: list of str or int, or Series or array-like
        Either names of columns in `data_frame`, or pandas Series, or
        array_like objects Values from these columns are extra data, to be used
        in widgets or Dash callbacks for example. This data is not user-visible
        but is included in events emitted by the figure (lasso selection etc.)
    text: str or int or Series or array-like
        Either a name of a column in `data_frame`, or a pandas Series or
        array_like object. Values from this column or array_like appear in the
        figure as text labels.
    facet_row: str or int or Series or array-like
        Either a name of a column in `data_frame`, or a pandas Series or
        array_like object. Values from this column or array_like are used to
        assign marks to facetted subplots in the vertical direction.
    facet_col: str or int or Series or array-like
        Either a name of a column in `data_frame`, or a pandas Series or
        array_like object. Values from this column or array_like are used to
        assign marks to facetted subplots in the horizontal direction.
    facet_col_wrap: int
        Maximum number of facet columns. Wraps the column variable at this
        width, so that the column facets span multiple rows. Ignored if 0, and
        forced to 0 if `facet_row` or a `marginal` is set.
    facet_row_spacing: float between 0 and 1
        Spacing between facet rows, in paper units. Default is 0.03 or 0.0.7
        when facet_col_wrap is used.
    facet_col_spacing: float between 0 and 1
        Spacing between facet columns, in paper units Default is 0.02.
    error_x: str or int or Series or array-like
        Either a name of a column in `data_frame`, or a pandas Series or
        array_like object. Values from this column or array_like are used to
        size x-axis error bars. If `error_x_minus` is `None`, error bars will
        be symmetrical, otherwise `error_x` is used for the positive direction
        only.
    error_x_minus: str or int or Series or array-like
        Either a name of a column in `data_frame`, or a pandas Series or
        array_like object. Values from this column or array_like are used to
        size x-axis error bars in the negative direction. Ignored if `error_x`
        is `None`.
    error_y: str or int or Series or array-like
        Either a name of a column in `data_frame`, or a pandas Series or
        array_like object. Values from this column or array_like are used to
        size y-axis error bars. If `error_y_minus` is `None`, error bars will
        be symmetrical, otherwise `error_y` is used for the positive direction
        only.
    error_y_minus: str or int or Series or array-like
        Either a name of a column in `data_frame`, or a pandas Series or
        array_like object. Values from this column or array_like are used to
        size y-axis error bars in the negative direction. Ignored if `error_y`
        is `None`.
    animation_frame: str or int or Series or array-like
        Either a name of a column in `data_frame`, or a pandas Series or
        array_like object. Values from this column or array_like are used to
        assign marks to animation frames.
    animation_group: str or int or Series or array-like
        Either a name of a column in `data_frame`, or a pandas Series or
        array_like object. Values from this column or array_like are used to
        provide object-constancy across animation frames: rows with matching
        `animation_group`s will be treated as if they describe the same object
        in each frame.
    category_orders: dict with str keys and list of str values (default `{}`)
        By default, in Python 3.6+, the order of categorical values in axes,
        legends and facets depends on the order in which these values are first
        encountered in `data_frame` (and no order is guaranteed by default in
        Python below 3.6). This parameter is used to force a specific ordering
        of values per column. The keys of this dict should correspond to column
        names, and the values should be lists of strings corresponding to the
        specific display order desired.
    labels: dict with str keys and str values (default `{}`)
        By default, column names are used in the figure for axis titles, legend
        entries and hovers. This parameter allows this to be overridden. The
        keys of this dict should correspond to column names, and the values
        should correspond to the desired label to be displayed.
    orientation: str, one of `'h'` for horizontal or `'v'` for vertical. 
        (default `'v'` if `x` and `y` are provided and both continous or both
        categorical,  otherwise `'v'`(`'h'`) if `x`(`y`) is categorical and
        `y`(`x`) is continuous,  otherwise `'v'`(`'h'`) if only `x`(`y`) is
        provided)
    color_discrete_sequence: list of str
        Strings should define valid CSS-colors. When `color` is set and the
        values in the corresponding column are not numeric, values in that
        column are assigned colors by cycling through `color_discrete_sequence`
        in the order described in `category_orders`, unless the value of
        `color` is a key in `color_discrete_map`. Various useful color
        sequences are available in the `plotly.express.colors` submodules,
        specifically `plotly.express.colors.qualitative`.
    color_discrete_map: dict with str keys and str values (default `{}`)
        String values should define valid CSS-colors Used to override
        `color_discrete_sequence` to assign a specific colors to marks
        corresponding with specific values. Keys in `color_discrete_map` should
        be values in the column denoted by `color`. Alternatively, if the
        values of `color` are valid colors, the string `'identity'` may be
        passed to cause them to be used directly.
    color_continuous_scale: list of str
        Strings should define valid CSS-colors This list is used to build a
        continuous color scale when the column denoted by `color` contains
        numeric data. Various useful color scales are available in the
        `plotly.express.colors` submodules, specifically
        `plotly.express.colors.sequential`, `plotly.express.colors.diverging`
        and `plotly.express.colors.cyclical`.
    range_color: list of two numbers
        If provided, overrides auto-scaling on the continuous color scale.
    color_continuous_midpoint: number (default `None`)
        If set, computes the bounds of the continuous color scale to have the
        desired midpoint. Setting this value is recommended when using
        `plotly.express.colors.diverging` color scales as the inputs to
        `color_continuous_scale`.
    symbol_sequence: list of str
        Strings should define valid plotly.js symbols. When `symbol` is set,
        values in that column are assigned symbols by cycling through
        `symbol_sequence` in the order described in `category_orders`, unless
        the value of `symbol` is a key in `symbol_map`.
    symbol_map: dict with str keys and str values (default `{}`)
        String values should define plotly.js symbols Used to override
        `symbol_sequence` to assign a specific symbols to marks corresponding
        with specific values. Keys in `symbol_map` should be values in the
        column denoted by `symbol`. Alternatively, if the values of `symbol`
        are valid symbol names, the string `'identity'` may be passed to cause
        them to be used directly.
    opacity: float
        Value between 0 and 1. Sets the opacity for markers.
    size_max: int (default `20`)
        Set the maximum mark size when using `size`.
    marginal_x: str
        One of `'rug'`, `'box'`, `'violin'`, or `'histogram'`. If set, a
        horizontal subplot is drawn above the main plot, visualizing the
        x-distribution.
    marginal_y: str
        One of `'rug'`, `'box'`, `'violin'`, or `'histogram'`. If set, a
        vertical subplot is drawn to the right of the main plot, visualizing
        the y-distribution.
    trendline: str
        One of `'ols'`, `'lowess'`, `'rolling'`, `'expanding'` or `'ewm'`. If
        `'ols'`, an Ordinary Least Squares regression line will be drawn for
        each discrete-color/symbol group. If `'lowess`', a Locally Weighted
        Scatterplot Smoothing line will be drawn for each discrete-color/symbol
        group. If `'rolling`', a Rolling (e.g. rolling average, rolling median)
        line will be drawn for each discrete-color/symbol group. If
        `'expanding`', an Expanding (e.g. expanding average, expanding sum)
        line will be drawn for each discrete-color/symbol group. If `'ewm`', an
        Exponentially Weighted Moment (e.g. exponentially-weighted moving
        average) line will be drawn for each discrete-color/symbol group. See
        the docstrings for the functions in
        `plotly.express.trendline_functions` for more details on these
        functions and how to configure them with the `trendline_options`
        argument.
    trendline_options: dict
        Options passed as the first argument to the function from
        `plotly.express.trendline_functions`  named in the `trendline`
        argument.
    trendline_color_override: str
        Valid CSS color. If provided, and if `trendline` is set, all trendlines
        will be drawn in this color rather than in the same color as the traces
        from which they draw their inputs.
    trendline_scope: str (one of `'trace'` or `'overall'`, default `'trace'`)
        If `'trace'`, then one trendline is drawn per trace (i.e. per color,
        symbol, facet, animation frame etc) and if `'overall'` then one
        trendline is computed for the entire dataset, and replicated across all
        facets.
    log_x: boolean (default `False`)
        If `True`, the x-axis is log-scaled in cartesian coordinates.
    log_y: boolean (default `False`)
        If `True`, the y-axis is log-scaled in cartesian coordinates.
    range_x: list of two numbers
        If provided, overrides auto-scaling on the x-axis in cartesian
        coordinates.
    range_y: list of two numbers
        If provided, overrides auto-scaling on the y-axis in cartesian
        coordinates.
    render_mode: str
        One of `'auto'`, `'svg'` or `'webgl'`, default `'auto'` Controls the
        browser API used to draw marks. `'svg`' is appropriate for figures of
        less than 1000 data points, and will allow for fully-vectorized output.
        `'webgl'` is likely necessary for acceptable performance above 1000
        points but rasterizes part of the output.  `'auto'` uses heuristics to
        choose the mode.
    title: str
        The figure title.
    template: str or dict or plotly.graph_objects.layout.Template instance
        The figure template name (must be a key in plotly.io.templates) or
        definition.
    width: int (default `None`)
        The figure width in pixels.
    height: int (default `None`)
        The figure height in pixels.
    
    Returns
    -------
        plotly.graph_objects.Figure
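Rather than reading the whole help text, you can list a function’s parameter names with the standard-library inspect module. In this sketch, scatter is a hypothetical stand-in defined with the same first three parameter names as px.scatter, so the code runs without Plotly installed.

```python
import inspect

def scatter(data_frame=None, x=None, y=None):
    """Hypothetical stand-in with the same first parameter names as px.scatter."""
    return (data_frame, x, y)

# Using a keyword name the function does not define raises a TypeError
try:
    scatter(data={"height": [70]}, x="height", y="weight")
except TypeError as err:
    print("TypeError:", err)

# inspect.signature lists the parameter names in order
params = list(inspect.signature(scatter).parameters)
print(params)  # ['data_frame', 'x', 'y']
```

With Plotly installed, list(inspect.signature(px.scatter).parameters) should likewise begin with 'data_frame', 'x', 'y', matching the signature shown in the help output.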