Chapter 6 Classification

6.1 Introduction to Classification

In classification, our goal is to assign each observation in the test dataset to one of a number of pre-specified categories. We do so using information from the observed predictor variables (or ‘features’), based on a classification rule derived using the training data. Applications of classification analysis are manifold, ranging from biological taxonomy (the assignment of organisms to species) to diagnosis of disease, to the assessment of credit risk.

As with the prediction problem, in practice we will have training and test datasets. For the training data we will have a record of both the features and the true class membership for each case. The training data can then be used to create a method for linking class membership to the features. In the test data we will have just the features. We can then apply the classification method developed on the training data to allocate each test case to a group.

As we did when discussing prediction, we will at times look at validation datasets (i.e. artificial test datasets) in which the true class memberships are available. This is for illustrative purposes, helping us to compare and contrast various classification techniques.

As an example of a classification problem, recall the Italian wine data that we met in Exercise 1 of Laboratory 4. Thinking of this as a training dataset, we display in Figure 6.1 a scatterplot of the percentage alcohol content of the wine against its hue, with the data plotted using different colour plotting symbols for each of the three cultivars. These three groups are quite well separated on just the plotted variables, suggesting that we should be able to derive a classification method with a high degree of accuracy when using information from all the variables.

Figure 6.1: Plot of percentage alcohol content against hue for a sample of Italian wines. The coloured plotting symbols distinguish the three different cultivars for the wines.

If we think of class membership as being defined by a categorical target variable, then this mirrors the prediction problem studied except for the nature of the target. It should therefore be no surprise that some of the techniques that we met in the context of prediction can also be applied to the classification problem with only minor modifications. Tree-based methods and neural networks are two prime examples that will be considered in this chapter. We shall also examine some specialized methods for classification, including logistic regression, linear discriminant analysis, kernel discriminant analysis and naive Bayes classifiers49.

In general, no classification method is likely to be 100% accurate. This is because there will always be inherent uncertainty on class membership based upon the observed features. For example, suppose that we try to classify human sex based on height and weight. A person of 185cm and 85kg is likely to be male, but is not certain to be so. The best we can do is to try and evaluate classification probabilities. These are the probabilities of membership of the possible classes given the observed features: \({\textsf P}(\mbox{class} | \mbox{features})\).

We will usually assign each test case to the class with highest classification probability. However, this is not always the case, since in some situations there are differing costs associated with different types of mistake. For example, consider a dermatologist classifying each skin lesion as benign or malignant. Mistakenly assigning a benign lesion to the malignant class will result in unnecessary treatment, but that is a less critical error than classifying a malignant lesion as benign, when the delay in treatment might be fatal. As a consequence, a dermatologist may act as if a lesion is malignant (for example, by excising it) even if \({\textsf P}(\mbox{malignant} | \mbox{features}) = 0.1\).

Our broad approach to classification must be to estimate the classification probabilities. There are two ways in which we might do this: generative and discriminative. Discriminative (or conditional) methods seek to estimate the classification probabilities directly. Generative methods compute the classification probabilities by first estimating related quantities, following a statistical theory for classification which we now discuss.

6.1.1 Statistical Theory for Classification

In classification we assume that each observation belongs to one of a finite number of classes, which we will label \(1, 2, \ldots, C\). We will make each classification based on the observed vector \(\boldsymbol{x}= (x_1, x_2, \ldots, x_p)\) of variables (features). Let the joint probability function50 of \(\boldsymbol{x}\) for individuals in class \(j\) be denoted by \(f_j\).

Now, suppose that we have some prior ideas about the probabilities of each class; denote these by \(\pi_1, \ldots, \pi_C\). If the training data are a random sample then we can use the relative proportions to provide these prior probabilities – that is, \(\pi_j = n_j/n\) where \(n_j\) is the number of training observations in class \(j\), and \(n\) is the total number of training observations. This is the default method of estimating the priors for most classifiers implemented in R. In other situations we might have exogenous information about the class probabilities. Sometimes we will have no useful information, when the best course may be to simply set them equal: \(\pi_j = 1/C\) for all \(j=1,\ldots,C\).

Suppose for now that the probability functions \(f_1, \ldots, f_C\) and prior probabilities are all known. We observe an individual with feature vector \(\boldsymbol{x}\). Which class is this individual most likely to belong to? To answer this, we must compute the conditional probabilities \({\textsf P}(j | \boldsymbol{x})\) for \(j=1,\ldots,C\). The quantity \({\textsf P}(j | \boldsymbol{x})\) is the probability that the individual belongs to class \(j\) given that they have feature vector \(\boldsymbol{x}\). Now, by Bayes’ Theorem51 we have \[\begin{equation} {\textsf P}(j | \boldsymbol{x}) = \frac{ \pi_j f_j(\boldsymbol{x}) }{f(\boldsymbol{x})}~~~~~~~~~~(j=1,\ldots,C) \tag{6.1} \end{equation}\] where \(f(\boldsymbol{x}) = \sum_j \pi_j f_j(\boldsymbol{x})\). Because the denominator on the right-hand side of Equation (6.1) does not depend upon the class \(j\), it follows that the most probable class is the one for which \(\pi_j f_j(\boldsymbol{x})\) is largest.

Example 6.1 Some Bayesian Reassurance for a Hypochondriac

A nervous patient has a set of blood test results \(\boldsymbol{x}\). He surfs the Internet, and finds that the probability that a normal healthy individual will return such results is \(0.01\), but that for people with Erdheim-Chester disease such test results would be reasonably common, appearing with probability \(0.5\). However, his GP (an avid Bayesian!) seems very unconcerned. She notes that this is a classification problem with two categories: diseased (class 1) and non-diseased (class 2). It is well known that Erdheim-Chester disease is a very rare condition, so that \(\pi_1 = 0.000001\) (i.e. 1 in a million), and hence \(\pi_2 = 1 - \pi_1 = 0.999999\). From the test results the GP knows that \(f_1(\boldsymbol{x}) = 0.5\) and \(f_2(\boldsymbol{x}) = 0.01\), but she computes the conditional probability of being diseased given the test results as \[{\textsf P}(1 | \boldsymbol{x}) = \frac{ 0.000001 \times 0.5 }{0.000001 \times 0.5 + 0.999999 \times 0.01} \approx 0.00005.\] It is extremely unlikely that the patient’s worries are well founded.
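
We can reproduce the GP’s calculation in a couple of lines of R; this is just Equation (6.1) applied to the two classes (the object names below are ours, purely for illustration):

prior <- c(diseased = 0.000001, healthy = 0.999999)
f <- c(diseased = 0.5, healthy = 0.01)      # probability of these test results in each class
posterior <- prior * f / sum(prior * f)     # Bayes' Theorem, Equation (6.1)
posterior["diseased"]                       # approximately 0.00005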

6.1.2 Misclassification Rates and Confusion Matrices

In order to implement classification based on the posterior probabilities in Equation (6.1) we need the prior probabilities \(\pi_j\) and the probability functions \(f_j\) for all classes \(j=1,\ldots,C\). Obtaining values for \(\pi_1,\ldots,\pi_C\) will not usually be a problem, as discussed above. The challenge is to obtain estimates of \(f_1,\ldots,f_C\) from the training data. We shall examine a number of options.

  • Assume that \(f_1,\ldots,f_C\) are normal probability density functions (for continuous random variables). This leads to the method of linear discriminant analysis (Section 6.2).

  • Estimate \(f_1,\ldots,f_C\) using kernel density estimation. This leads to the method of kernel discriminant analysis (Section 6.3).

  • Factorize the joint probability functions to facilitate their estimation, leading to naive Bayesian classification (Section 6.4).

Given this variety of methods, as well as the more generic tree-based and neural network classifiers, which should we prefer? We will usually want to select the method with the lowest misclassification rate. The misclassification rate is simply the proportion of observations that are assigned to the wrong class. A more detailed analysis of the performance of each classifier can be obtained from the error matrix or confusion matrix, which cross-tabulates true and predicted classes. In some cases the impact of certain types of misclassification may be far greater than others, in which case one may wish to consider a weighted error rate. This can occur, for example, in medical testing, where the consequences of false positives may be somewhat unpleasant (e.g. unnecessary biopsies) but the consequence of false negatives can be fatal (e.g. failure to detect a disease while it is still treatable).

Example 6.2 Comparison of Tests for Tuberculosis

The Mantoux and Heaf tests are both used to detect tuberculosis in patients. Suppose that they are each applied to 1000 subjects suspected of having the disease, giving the following confusion tables. Note that disease status is the true class here (positive or negative for tuberculosis) and the test result (positive or negative) is the predicted class.

Table 6.1: Performance of the Mantoux and Heaf tests for detecting tuberculosis

Mantoux test   Disease +   Disease -
Test +                50          50
Test -                 1         899

Heaf test      Disease +   Disease -
Test +                40          30
Test -                11         929

The overall misclassification rate for the Mantoux test is \(51/1000 = 5.1\%\), while the overall misclassification rate for the Heaf test is \(41/1000 = 4.1\%\). However, if we assign severity weights 1 to false positives and 10 to false negatives, then the weighted error is \((50+1 \times 10)/1000 = 0.06\) for the Mantoux test, which is rather better than the figure of \((30+11\times 10)/1000 = 0.14\) for the Heaf test.
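
These weighted error rates are simple to compute by hand, but a small helper function (our own, written just for illustration) makes the weighting explicit:

weighted_error <- function(fp, fn, n, w_fp = 1, w_fn = 10) {
  (w_fp * fp + w_fn * fn) / n                # weighted error count per subject
}
weighted_error(fp = 50, fn = 1, n = 1000)    # Mantoux: 0.06
weighted_error(fp = 30, fn = 11, n = 1000)   # Heaf: 0.14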

In practice we will need to estimate the misclassification rate for each classifier that we use. Recall from the previous chapter that assessing predictive accuracy (in terms of mean squared error) based on predictions on the training data led to an overly optimistic assessment. The methods enjoyed ‘home ground advantage’. The same issue applies in classification. If we develop a classification rule on the training data, then it will tend to perform especially well on those data, and in particular better than it would on real test data. We therefore need a more independent method for estimating misclassification rates.

If we have an independent validation dataset for which the correct class assignments are known, then misclassification rates can be assessed based on performance thereon. As we saw previously, this could be achieved by splitting the original training dataset into a smaller (new) training set and a separate validation set, typically in a ratio of around 3:1. Alternatively, we can employ cross-validation techniques, following along the same lines as in Section 4.4.5 (but with the squared error performance measure replaced by the misclassification rate).

6.2 Linear Discriminant Analysis

6.2.1 The Method in Theory

Linear discriminant analysis (often abbreviated to LDA) is applicable when all the features are quantitative. The method relies critically on the assumption that the joint distribution of the features is multivariate normal, with a covariance matrix that is the same from class to class. What distinguishes the groups is the inter-class variation in the vector of feature means.

Consider a very simple case where there are just \(C=2\) classes, and just \(p=2\) numerical features, \(x_1\) and \(x_2\). Suppose that these features are statistically independent, and that the standard deviation of each is equal to one (irrespective of class). Finally, let both features have mean zero in group 1, and let both features have mean 1 in group 2. In other words, if \(\mu_{ij}\) is the mean for feature \(i\) in group \(j\) then \(\mu_{11} = \mu_{21} = 0\), and \(\mu_{12} = \mu_{22} = 1\). Now, using standard results for the multivariate normal distribution52, we have \[\begin{aligned} f_j(\boldsymbol{x}) &= \frac{1}{2 \pi} \exp \left \{ -\frac {(x_1 - \mu_{1j})^2}{2} \right \} \exp \left \{ -\frac{(x_2 - \mu_{2j})^2}{2} \right \}\\ &= \left \{ \begin{array}{ll} \frac{1}{2 \pi} \exp \left \{ -\frac{x_1^2}{2} - \frac{x_2^2}{2} \right \} & \mbox{for class $j=1$}\\ \frac{1}{2 \pi} \exp \left \{ -\frac{(x_1-1)^2}{2} - \frac{(x_2-1)^2}{2} \right \} & \mbox{for class $j=2$} \end{array} \right .\end{aligned}\] If the prior class probabilities are equal (i.e. \(\pi_1 = \pi_2 = 0.5\)) then we will assign an observation \(\boldsymbol{x}= (x_1, x_2)\) to class \(1\) if and only if \[\begin{aligned} f_1(\boldsymbol{x}) &> f_2(\boldsymbol{x}) \\ \Rightarrow\quad \exp\{ (-x_1^2/2 - x_2^2/2)\} &> \exp\{ (-(x_1-1)^2/2 - (x_2-1)^2/2)\} \\ \Rightarrow\quad -x_1^2/2 - x_2^2/2 &> -(x_1-1)^2/2 - (x_2-1)^2/2 \\ \Rightarrow\quad x_1^2 + x_2^2 &< x_1^2 - 2 x_1 + 1 + x_2^2 - 2 x_2 +1 \\ \Rightarrow\quad 0 &< - 2 x_1 + 1 - 2 x_2 +1 \\ \Rightarrow\quad x_1 + x_2 &< 1\\\end{aligned}\] Time to take stock. We have just shown that given an observation \((x_1,x_2)\), we should assign it to class 1 if \(x_1 + x_2 < 1\), and to class 2 if \(x_1 + x_2 \ge 1\). (In theory the possibility of equality can be ignored, because for truly continuous data we will get precise equality with probability zero.) In other words, \(x_1 + x_2 < 1\) is a decision rule that tells us where to classify the observation (class 1 if true, class 2 if false). This can be plotted graphically, as in Figure 6.2. Notice that the boundary of the regions is given by the straight line \(x_1 + x_2 = 1\), hence the nomenclature linear discriminant analysis. The combination \(x_1 + x_2\) is referred to as a linear discriminant or a discriminant variable because of its role in defining the decision rule.

Figure 6.2: Plot of decision regions for a linear discriminant analysis with two groups. Any observation in the grey region will be assigned to class 1; any observation in the cyan region will be assigned to class 2.

The result is very intuitive. Each observation will be assigned to the class for which the mean is closest. While we have demonstrated this result for two-class problems, it extends naturally to cases with three or more classes. In such situations the feature space will be divided into multiple regions using line segment boundaries. See Figure 6.3 for example. If there are more than two features then the boundaries of the decision regions will be defined by the intersection of hyperplanes rather than lines.

Figure 6.3: Plot of decision regions for a linear discriminant analysis with three groups.
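
The nearest-mean interpretation can be checked directly in R. The following minimal sketch (for the two-class example above, with class means \((0,0)\) and \((1,1)\) and equal priors) classifies a point by comparing its squared distances to the class means, and so reproduces the rule \(x_1 + x_2 < 1\):

mu <- rbind(c(0, 0), c(1, 1))        # row j holds the mean vector for class j
classify <- function(x) {
  d <- colSums((t(mu) - x)^2)        # squared distance from x to each class mean
  which.min(d)                       # index of the closest mean
}
classify(c(0.2, 0.3))                # 0.2 + 0.3 < 1, so class 1
classify(c(0.8, 0.9))                # 0.8 + 0.9 >= 1, so class 2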

6.2.2 Implementation in R

Linear discriminant analysis can be implemented in R using the lda function. This is part of the MASS library, which must be pre-loaded. The syntax for lda is of the following form:

lda(y ~ x1 + x2, data=mydata, CV=FALSE, prior=c(0.5,0.5))

Here y in the formula is the true class, and x1 and x2 are features. The argument data allows specification of a data frame in the usual manner. The optional argument CV is logical, indicating whether or not a cross-validation estimate of the posterior probabilities \(\{{\textsf P}(j | \boldsymbol{x}) \}\) is required. The default is CV=FALSE. Finally, prior probabilities can be specified as a vector valued argument prior. The default is to use the class proportions in the data.

Predicted classes from an LDA model are obtained by the predict function in the usual kind of way.
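
For example, a minimal sketch of the base-R workflow (where mydata and mynewdata are hypothetical training and test data frames) is:

library(MASS)
fit <- lda(y ~ x1 + x2, data = mydata)
pred <- predict(fit, newdata = mynewdata)
pred$class        # predicted classes for the test cases
pred$posterior    # estimated posterior probabilities P(j | x)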

While this is relatively consistent with what we’ve seen in the past, we’ll be utilising the tidymodels framework instead, using parsnip with the discrim_linear() function. We can then use the yardstick package to evaluate performance.

Example 6.3 Prediction for the Wine Data Using LDA

We start by splitting the wine data into training and test data (which we shall refer to throughout this chapter). The wine dataset is in truth a relatively straightforward one for classification, so we make things a little more challenging by leaving only 100 observations in the training set (with the remaining 78 in the validation dataset). These operations are performed in the R code below, where we first convert the target variable Cultivar to a factor.

library(tidyverse)
library(rsample)
wine <- read_csv("../data/wine.csv") |>
  mutate(Cultivar = factor(Cultivar))

set.seed(2021)
split <- initial_split(wine, prop=100/178)
wine.train <- training(split)
wine.test <- testing(split)
wine.train
#> # A tibble: 100 × 14
#>    Cultivar Alcohol Malic   Ash Alc.Ash    Mg Phenols  Flav Non.Flav Proan
#>    <fct>      <dbl> <dbl> <dbl>   <dbl> <dbl>   <dbl> <dbl>    <dbl> <dbl>
#>  1 C          12.51  1.24  2.25    17.5    85    2     0.58     0.6   1.25
#>  2 C          13.73  4.36  2.26    22.5    88    1.28  0.47     0.52  1.15
#>  3 C          13.71  5.65  2.45    20.5    95    1.68  0.61     0.52  1.06
#>  4 C          12.84  2.96  2.61    24     101    2.32  0.6      0.53  0.81
#>  5 B          12.21  1.19  1.75    16.8   151    1.85  1.28     0.14  2.5 
#>  6 C          13.17  2.59  2.37    20     120    1.65  0.68     0.53  1.46
#>  7 B          12.6   1.34  1.9     18.5    88    1.45  1.36     0.29  1.35
#>  8 B          11.61  1.35  2.7     20      94    2.74  2.92     0.29  2.49
#>  9 B          13.34  0.94  2.36    17     110    2.53  1.3      0.55  0.42
#> 10 B          12.34  2.45  2.46    21      98    2.56  2.11     0.34  1.31
#> # ℹ 90 more rows
#> # ℹ 4 more variables: Colour <dbl>, Hue <dbl>, ODRatio <dbl>, Proline <dbl>

We initially consider prediction with LDA using all the variables.

library(discrim)
library(yardstick)
wine.lda1 <- discrim_linear() |> fit(Cultivar ~ ., data=wine.train)
wine.lda1
#> parsnip model object
#> 
#> Call:
#> lda(Cultivar ~ ., data = data)
#> 
#> Prior probabilities of groups:
#>    A    B    C 
#> 0.27 0.40 0.33 
#> 
#> Group means:
#>    Alcohol    Malic      Ash  Alc.Ash        Mg  Phenols      Flav  Non.Flav
#> A 13.72333 1.977407 2.509630 17.52593 105.96296 2.831111 3.0022222 0.2981481
#> B 12.24400 2.103250 2.229750 20.07250  95.50000 2.262500 2.0822500 0.3525000
#> C 13.11576 3.340909 2.418788 21.28788  99.54545 1.688182 0.7481818 0.4542424
#>      Proan   Colour       Hue  ODRatio   Proline
#> A 1.826296 5.444444 1.0566667 3.162593 1079.6296
#> B 1.690500 2.924250 1.0540000 2.922000  521.7750
#> C 1.136061 7.250606 0.6809091 1.709697  626.5152
#> 
#> Coefficients of linear discriminants:
#>                   LD1          LD2
#> Alcohol  -0.010317451  1.175068679
#> Malic     0.182260164  0.104519875
#> Ash       0.138964990  2.573625788
#> Alc.Ash   0.038864464 -0.160828615
#> Mg       -0.003958634  0.004833826
#> Phenols   0.767868607 -0.086488945
#> Flav     -2.569432136  0.050198155
#> Non.Flav -2.044924752 -0.504567571
#> Proan     0.358496497 -0.403501573
#> Colour    0.298461150  0.132267168
#> Hue      -1.850873874 -1.976866286
#> ODRatio  -1.288883722  0.078974327
#> Proline  -0.001686943  0.002915649
#> 
#> Proportion of trace:
#>    LD1    LD2 
#> 0.7512 0.2488
wine.lda1.pred <- wine.lda1 |> augment(new_data=wine.test)
wine.lda1.pred |>
  conf_mat(truth=Cultivar, estimate=.pred_class)
#>           Truth
#> Prediction  A  B  C
#>          A 32  1  0
#>          B  0 28  0
#>          C  0  2 15

Note the following:

  • We have first loaded the discrim library, an add-on package for parsnip, in order to access the discrim_linear function.

  • The dot on the right-hand side of a model formula has the usual meaning (of all variables in the data frame barring the response on the left-hand side).

  • The fit object wine.lda1 is a parsnip model object; printing it shows the output from the underlying lda() call. Much of the resulting output is self explanatory. We note that the table of coefficients of the linear discriminants is a matrix which transforms standardized versions of the original observations to the linear discriminants.

  • We use augment() to add the .pred_class, .pred_A, .pred_B and .pred_C columns to the wine.test data. The first of these is the “hard” class prediction (based on the maximum posterior probability) while the last three are the aforementioned posterior probabilities.

  • The yardstick function conf_mat tabulates the predicted classes against the true classes; i.e. the confusion matrix. We see that the overall misclassification rate is \(3/78 \approx 3.8\%\).
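
If we want the performance as a single number rather than reading it off the confusion matrix, the yardstick function accuracy() can be applied to the same augmented predictions; here it gives \(75/78 \approx 96.2\%\), the complement of the misclassification rate.

wine.lda1.pred |>
  accuracy(truth=Cultivar, estimate=.pred_class)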

We now consider trying to perform LDA using just the variables Alcohol, Colour and Hue.

wine.lda2 <- discrim_linear() |> fit(Cultivar ~ Alcohol + Colour + Hue, data=wine.train)
wine.lda2.pred <- wine.lda2 |> augment(new_data=wine.test)
wine.lda2.pred |> conf_mat(truth=Cultivar, estimate=.pred_class)
#>           Truth
#> Prediction  A  B  C
#>          A 30  3  1
#>          B  2 27  1
#>          C  0  1 13

We get a worse misclassification rate in this case: \(8/78 \approx 10.3\%\). Nonetheless, we used far fewer variables to develop the classification model. The question remains, might we improve on the initial (full) model by omitting some of the variables? This is another instance of a bias-variance trade-off. In principle we could perform some kind of variable selection algorithm to address this matter, but it turns out that there are better ‘dimension reducing’ approaches (though these are beyond the scope of this paper).

Finally, we have used the training sample relative frequencies to define the prior probabilities in both of these LDA classifiers. As you will see from the output for wine.lda1, this gave \((\pi_1,\pi_2,\pi_3) = (0.27, 0.4, 0.33)\). We can instead specify our own prior probabilities through the prior argument to lda(). For example, suppose that we thought that in practice all three cultivars are equally prevalent. Then we could perform classification using the prior \((\pi_1,\pi_2,\pi_3) = (1/3, 1/3, 1/3)\) instead. To pass the prior information through from discrim_linear() we would use the set_engine() function:

spec_lda_equal_priors <- discrim_linear() |> set_engine("MASS", prior = c(1,1,1)/3)

Note that we wouldn’t expect the classifications to change much as the prior is only a little different.
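
For example, a sketch of fitting and assessing this equal-prior specification (output not shown) would be:

wine.lda3 <- spec_lda_equal_priors |>
  fit(Cultivar ~ ., data=wine.train)
wine.lda3 |>
  augment(new_data=wine.test) |>
  conf_mat(truth=Cultivar, estimate=.pred_class)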

6.3 Kernel Discriminant Analysis

While linear discriminant analysis can work well in many applications, it can also fail spectacularly when the boundaries between the groups are highly non-linear. Consider, for instance, the (artificial) data set displayed in Figure 6.4. The separation between the red and black observations is clear to the eye, but linear discriminant analysis cannot handle such a scenario.

Figure 6.4: An artificial dataset for which linear discriminant analysis will completely fail.

A (far) more flexible approach than LDA is to estimate the probability functions \(f_1,\ldots,f_C\) using kernel density estimation.

Recall that kernel density estimation in one dimension estimates a distribution \(f\) by averaging kernel functions (typically Gaussian) centered at each observation:

\[ \hat{f}(x) = \frac{1}{n}\sum_{i=1}^n \frac{1}{h}K(\frac{x - x_i}{h}). \]

Here, \(h\) is the bandwidth of the kernel density estimator, which essentially controls the amount of smoothing: a small value of \(h\) will result in a noisy distribution, while a large value of \(h\) will result in something that may be overly smooth and miss fine detail.
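
The effect of the bandwidth is easy to see using base R’s density() function. As a small sketch (using the Alcohol variable from the wine training data, with two arbitrarily chosen bandwidths):

d.noisy  <- density(wine.train$Alcohol, bw=0.05)   # small bandwidth: wiggly estimate
d.smooth <- density(wine.train$Alcohol, bw=1)      # large bandwidth: very smooth estimate
plot(d.noisy, main="Kernel density estimates of Alcohol")
lines(d.smooth, lty=2)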

We can extend this concept readily to higher dimensions, though the scalar bandwidth \(h\) is replaced by a \(d \times d\) matrix of bandwidths when the dimension is \(d\).

In the case where we have continuous predictor variables, we can use the training data from each class to construct kernel estimates \(\hat f_1, \ldots, \hat f_C\). Substituting these estimates in Equation (6.1) leads to kernel discriminant analysis.

While kernel discriminant analysis wins over LDA in terms of flexibility, the method can be somewhat unreliable when there are large numbers of predictor variables. Indeed, kernel density estimation will typically prove difficult in more than 3 dimensions, unless there are very large numbers of observations available.

Kernel discriminant analysis is implemented in R using the kda function within the ks library. To keep things consistent, we will use the tidykda package53 which implements a wrapper to coerce kernel discriminant analysis into the parsnip framework.

The bandwidths are found automatically using a plug-in estimator, so all we need to do is provide at most 3 predictors and utilise the discrim_kernel() function. This can take some time to run for large datasets.

Example 6.4 Kernel Discriminant Analysis for the Wine Data

We apply kernel discriminant analysis to the Italian wine data. Specifically, we attempt to classify Cultivar based on the three predictors Alcohol, Colour and Hue.

library(tidykda)
wine.kda <- discrim_kernel() |> fit(Cultivar ~ Alcohol + Colour + Hue, data=wine.train)
wine.kda.pred <- wine.kda |> augment(new_data=wine.test)
wine.kda.pred |> conf_mat(truth=Cultivar, estimate=.pred_class)
#>           Truth
#> Prediction  A  B  C
#>          A 29  6  2
#>          B  0 23  0
#>          C  3  2 13

Points to note:

  • We use only 3 predictors, as this is the most that discrim_kernel will accept.

  • The error rate using kernel discriminant analysis is \(13/78 \approx 16.7\%\). This is worse than we obtained using LDA. Hence, although linear boundaries may not be entirely satisfactory for distinguishing between these classes, the errors arising from that restriction matter less in this example than the extra variability inherent in the more complex kernel approach.

6.4 Naive Bayes Classifiers

As we noted above, kernel discriminant analysis does not usually work at all well when there are many predictor variables. Indeed, the R implementation of the method can cope with no more than \(p=6\). This is a reflection of the curse of dimensionality, which in essence states that many statistical estimation problems become much more difficult in high dimensions.

This idea can be illustrated most easily if we consider the situation in which the predictor variables are categorical. Specifically, imagine that we have \(p=10\) factors, each on 3 levels. To specify \(f_j(\boldsymbol{x})\) we must compute the probability for every combination of the factors. However, there are \(3^{10} = 59049\) different combinations. That’s a lot! Even if we had a large training set with \(10000\) observations per class, it would still be the case that we would have no training observations for at least \(83\%\) of the possible factorial combinations. How then can we hope to estimate the probabilities thereof?

In naive Bayes classification this problem is addressed by making the assumption that all the predictors are statistically independent within each class. This means that we can factorize the probability functions by \[\begin{equation} f_j(\boldsymbol{x}) = f_j(x_1) f_j(x_2) \cdots f_j(x_p) \tag{6.2} \end{equation}\] where \(f_j(x_i)\) denotes the marginal probability function for predictor \(x_i\) in class \(j\). To appreciate the extent to which this simplifies the problem, let us return to our example with 10 predictive factors each with 3 levels. In order to estimate the marginal probability distribution for any single factor we need only estimate 2 probabilities54. This in turn means that we need only estimate \(10\times 2=20\) probabilities in total to fully specify the probability function in Equation (6.2).

Example 6.5 Coke or Pepsi?

A sample of 100 people are questioned as to whether they prefer Coke or Pepsi (the target classification). A total of 60 prefer Coke and 40 prefer Pepsi. The predictive factors are sex and age category (younger than 20, or 20 and older) of person. The following results were recorded.

Table 6.2: Preference of Coke (top) and Pepsi (bottom) by sex and age

Coke           Younger   Older
Male                20      10
Female              10      20

Pepsi          Younger   Older
Male                10      20
Female               5       5

To specify the probability distribution \(f_1\) for Coke drinkers we must compute probabilities for each combination of predictive factors. We can do so using the observed proportions. Under the full Bayesian model we have \[\hat f_1(\{Male,Older\}) = \tfrac{10}{60} = \tfrac{1}{6}\] for example. Using naive Bayes on the other hand, we obtain \[\hat f_1(\{Male,Older\}) \approx \hat f_1(Male) \hat f_1(Older) = \tfrac{30}{60} \times \tfrac{30}{60} = \tfrac{1}{4}.\]
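
These two calculations are easily reproduced in R by entering the Coke table as a small matrix:

coke <- matrix(c(20, 10, 10, 20), nrow=2, byrow=TRUE,
               dimnames=list(c("Male", "Female"), c("Younger", "Older")))
coke["Male", "Older"] / sum(coke)                                       # full model: 10/60
(sum(coke["Male", ]) / sum(coke)) * (sum(coke[, "Older"]) / sum(coke))  # naive Bayes: 1/4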

Naturally there is a price to pay for such a huge reduction in model complexity. The assumption of independence of predictors within each class is strong, and will almost certainly be incorrect in practice. Nonetheless, such a simplification may still be a good idea, based on the kind of bias-variance trade off arguments with which you should now be familiar.

Naive Bayes classification in R is implemented in a range of different packages, such as the e1071 package (with function naiveBayes()), the naivebayes package (with function naive_bayes()) and the klaR package (with function NaiveBayes()).

As before, we will be using the tidymodels framework, where discrim provides the model specification naive_Bayes(). We’ll utilise the naivebayes engine. Predictions will then be made in the same way (i.e. using predict() or augment()).

The implementation of the naive Bayes classifier can take both quantitative and categorical predictors. The marginal distributions (i.e. the \(f_j(x_i)\) mentioned above) for the numerical variables are estimated using kernel density estimators (or, alternatively univariate normal distributions), while marginal probabilities for the categorical variables are estimated using class-specific proportions.

Example 6.6 Naive Bayes Classification for the Wine Data

We apply naive Bayes classification to the Italian wine data, first using all the predictor variables and then using just Alcohol, Colour and Hue. Lastly, we’ll use a normal distribution in place of the kernel density estimates.

spec_nB <- naive_Bayes(engine="naivebayes")
wine.nB.1 <- spec_nB |>
  fit(Cultivar ~ ., data=wine.train)
wine.nB.1.pred <- wine.nB.1 |>
  augment(new_data=wine.test)
wine.nB.1.pred |>
  conf_mat(truth=Cultivar, estimate=.pred_class)
#>           Truth
#> Prediction  A  B  C
#>          A 30  2  0
#>          B  2 26  0
#>          C  0  3 15
wine.nB.2 <- spec_nB |>
  fit(Cultivar ~ Alcohol + Colour + Hue, data=wine.train)
wine.nB.2.pred <- wine.nB.2 |>
  augment(new_data=wine.test)
wine.nB.2.pred |>
  conf_mat(truth=Cultivar, estimate=.pred_class)
#>           Truth
#> Prediction  A  B  C
#>          A 32  4  1
#>          B  0 25  1
#>          C  0  2 13
wine.nB.3 <- spec_nB |>
  set_engine(engine="naivebayes", usekernel=FALSE) |>
  fit(Cultivar ~ Alcohol + Colour + Hue, data=wine.train)
wine.nB.3
#> parsnip model object
#> 
#> 
#> ================================== Naive Bayes ================================== 
#>  
#>  Call: 
#> naive_bayes.default(x = maybe_data_frame(x), y = y, usekernel = ~FALSE)
#> 
#> --------------------------------------------------------------------------------- 
#>  
#> Laplace smoothing: 0
#> 
#> --------------------------------------------------------------------------------- 
#>  
#>  A priori probabilities: 
#> 
#>    A    B    C 
#> 0.27 0.40 0.33 
#> 
#> --------------------------------------------------------------------------------- 
#>  
#>  Tables: 
#> 
#> --------------------------------------------------------------------------------- 
#>  ::: Alcohol (Gaussian) 
#> --------------------------------------------------------------------------------- 
#>        
#> Alcohol          A          B          C
#>    mean 13.7233333 12.2440000 13.1157576
#>    sd    0.4691646  0.4963188  0.5086258
#> 
#> --------------------------------------------------------------------------------- 
#>  ::: Colour (Gaussian) 
#> --------------------------------------------------------------------------------- 
#>       
#> Colour         A         B         C
#>   mean 5.4444444 2.9242500 7.2506060
#>   sd   1.1854546 0.7474191 2.1912438
#> 
#> --------------------------------------------------------------------------------- 
#>  ::: Hue (Gaussian) 
#> --------------------------------------------------------------------------------- 
#>       
#> Hue            A         B         C
#>   mean 1.0566667 1.0540000 0.6809091
#>   sd   0.1271885 0.1816562 0.1231961
#> 
#> ---------------------------------------------------------------------------------
wine.nB.3.pred <- wine.nB.3 |>
  augment(new_data=wine.test)
wine.nB.3.pred |>
  conf_mat(truth=Cultivar, estimate=.pred_class)
#>           Truth
#> Prediction  A  B  C
#>          A 30  4  1
#>          B  2 24  1
#>          C  0  3 13

Points to note:

  • The printed output for the fitted objects (e.g. wine.nB.3) isn’t especially useful. There is a table per predictor with a summary of the model fit - kernel densities, or the mean/sd of the normal distributions as seen for wine.nB.3.

  • The naive Bayes classifier implemented with all predictors has an error rate of \(7/78 \approx 9\%\). With just the predictors Alcohol, Colour and Hue this rises to \(8/78 \approx 10.3\%\). Using normal distributions rather than kernel densities for the marginal distributions is worse again, at \(11/78 \approx 14.1\%\).

6.5 Logistic regression

Logistic regression is a type of discriminative classifier. This, along with the other methods of classification in the next section, attempts to learn the probability \(P(j|\boldsymbol{x})\) directly, rather than generating it from the joint probability function within each class and prior probabilities via Bayes’ Theorem.

Logistic regression learns \(P(j|\boldsymbol{x})\) directly by assuming a parametric form for the probability. It is one of a class of methods known as generalised linear models (GLMs), and is thus closely related to linear regression. The generalisation involved is that the distribution of the response variable \(y\) need not be a normal distribution. In particular, logistic regression is primarily concerned with a binary (i.e. 0 and 1) response variable, whose expected value will be the probability of a response \(y=1\). Thus, it is useful for classifying data into two groups, which we can associate with the two outcomes – say \(y=0\) codes for group 1, and \(y=1\) for group 2. Extensions to more than two groups will be given in Section 6.5.3.

It is linear in the sense that a transformation links a linear combination of predictors \[\theta(\boldsymbol{x}) = \beta_0 + \beta_1 x_1 + \cdots + \beta_p x_p,\] with the probability \(p(\boldsymbol{x}) = {\textsf P}(y=1 | \boldsymbol{x})\)55. This transformation is the sigmoid function56, also known as the logistic function: \[p(\boldsymbol{x}) = \frac{1}{1+e^{-\theta(\boldsymbol{x})}}.\] The logistic function ensures that the expected outcome is between 0 and 1, see the left-hand panel of Figure 6.5.

Figure 6.5: The logistic and logit functions.

The inverse of the logistic function is the logit function \[\theta(\boldsymbol{x}) = {\textsf {logit}}(p(\boldsymbol{x})) = \log\left(\frac{p(\boldsymbol{x})}{1-p(\boldsymbol{x})}\right).\] See the right-hand panel of Figure 6.5. The logit may also be interpreted as the log odds, where the odds are defined as \(\tfrac{p(\boldsymbol{x})}{1-p(\boldsymbol{x})}\). Odds are an alternate scale for interpreting the chance of something happening. Rather than writing the probability of it occurring, we instead give a ratio of the probability that it occurs to the probability that it does not occur. As an example, the odds of something that occurs twice in five are 2:3 – you expect it to occur twice for every 3 times it does not occur. Notice that, while proportions are restricted between 0 and 1, the odds may take any positive number, and hence the log odds may take any real number. This makes the log odds appropriate for a linear combination of predictors which, if we allow any possible values for the predictors, will also produce all real numbers.
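
As a quick numerical illustration of moving between probabilities, odds and log odds for the ‘twice in five’ example:

p <- 2/5
odds <- p / (1 - p)     # 2/3, i.e. odds of 2:3
log(odds)               # the log odds (logit), about -0.41
plogis(log(odds))       # the logistic function returns us to p = 0.4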

The formulation of logistic regression is therefore: \[\begin{aligned} y_i &\sim \mathsf{Bernoulli}(p(\boldsymbol{x}_i))\\ {\textsf {logit}}(p(\boldsymbol{x}_i)) &= \beta_0 + \beta_1 x_{i1} + \cdots + \beta_p x_{ip}.\end{aligned}\] where

  • \(y_i\) is the binary outcome for observation \(i\);

  • \(\boldsymbol{x}_i = (x_{i1}, \ldots, x_{ip})^\mathsf{T}\) is the vector of predictor variables for the \(i\)th observation;

  • \(p(\boldsymbol{x}_i) = {\textsf P}(y_i = 1 | \boldsymbol{x}_i)\);

  • \(\beta_0, \beta_1, \ldots, \beta_p\) are unknown regression coefficients.

Note that, as with linear regression, we can use either quantitative variables or factors as predictors. When factors are used, we code them as 0/1 dummy variables in the same way as is done for linear regression.

There is then a natural way of interpreting the fitted coefficients for factors in terms of odds. Suppose we have a factor with one of its levels coded as the binary variable \(x_1\), and let \(x_2, \ldots, x_p\) be other predictors (including any other levels of the factor). Then if \(x_1\) is a 1, the log odds will be given by \[{\textsf {logit}}(p(\boldsymbol{x})) = \beta_0 + \beta_1 + \sum_{j=2}^p \beta_j x_j.\] If \(x_1\) is a 0, the log odds is given by \[{\textsf {logit}}(p(\boldsymbol{x})) = \beta_0 + \sum_{j=2}^p \beta_jx_j.\] Thus, the ratio of the odds, or odds ratio associated with \(x_1\) is given by \[\begin{aligned} \mathsf{OR}(x_1) &= \frac{e^{\beta_0 + \beta_1 + \sum_{j=2}^p \beta_jx_j}}{e^{\beta_0 + \sum_{j=2}^p\beta_jx_j}}\\ & = e^{\beta_1}.\end{aligned}\] An odds ratio for \(x_1\) greater than one means the odds that \(y=1\) increase when \(x_1\) is 1; an odds ratio less than one means the odds that \(y=1\) decrease when \(x_1\) is 1. Thus, the odds increase if \(\beta_1 > 0\) and decrease if \(\beta_1 < 0\).
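
In R, an odds ratio is obtained by exponentiating the corresponding fitted coefficient. A minimal sketch on simulated data (the data frame toy and the 0/1 dummy x1 below are invented for illustration, and glm() itself is introduced in the next subsection):

set.seed(1)
toy <- data.frame(x1 = rep(0:1, each = 50))
toy$y <- rbinom(100, 1, ifelse(toy$x1 == 1, 0.7, 0.4))   # higher chance of y=1 when x1 is 1
toy.fit <- glm(y ~ x1, family = binomial, data = toy)
exp(coef(toy.fit))["x1"]                                  # estimated odds ratio for x1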

6.5.1 Fitting Logistic Regression Models

Unlike linear regression, we cannot estimate the unknown regression coefficients \(\beta_1, \ldots, \beta_p\) via a closed form expression that minimises the sum of squared errors. Instead, we must find the regression coefficients by an iterative procedure which attempts to maximise the likelihood of the data57. The likelihood is simply the probability of the data regarded as a function of the parameters.

For a particular observation \(y_i\), the likelihood for logistic regression is \[\begin{aligned} {\textsf P}(y_i|\boldsymbol{x}_i,\boldsymbol{\beta}) &= \left\{\begin{array}{ll}p & \textrm{if}\quad y_i = 1\\ 1-p & \textrm{if}\quad y_i = 0,\end{array}\right.\\ &= \left\{\begin{array}{ll}\frac{1}{1 + e^{-\beta_0 - \sum_j \beta_j x_{ij}}} & \textrm{if}\quad y_i = 1\\ \\ \frac{e^{-\beta_0 - \sum_j \beta_j x_{ij}}}{1 + e^{-\beta_0 - \sum_j \beta_j x_{ij}}} & \textrm{if}\quad y_i = 0,\end{array}\right.\\ & = \frac{e^{(-\beta_0 -\sum_j \beta_j x_{ij})(1-y_i)}}{1 + e^{-\beta_0 -\sum_j \beta_i x_{ij}}}.\end{aligned}\] Thus, the likelihood of the full data arising, assuming the data are independent, is \[\begin{equation} L(\boldsymbol{\beta}) = P(\boldsymbol{y}|\boldsymbol{x},\boldsymbol{\beta}) = \prod_{i=1}^n \frac{e^{(-\beta_0 - \sum_j \beta_j x_{ij})(1-y_i)}}{1 + e^{-\beta_0 - \sum_j \beta_j x_{ij}}}. \tag{6.3} \end{equation}\] The parameter vector \(\boldsymbol{\beta}\) that maximises equation (6.3) will be the one that is most likely to produce the data we observe, known as the maximum likelihood estimate 58. In practice, we typically maximise the log-likelihood function instead \[\begin{aligned} \ell(\boldsymbol{\beta}) = \log L(\boldsymbol{\beta}) &= \sum_{i=1}^n \log\left(\frac{e^{(-\beta_0 - \sum_j \beta_j x_{ij})(1-y_i)}}{1 + e^{-\beta_0 - \sum_j \beta_j x_{ij}}}\right)\\ & = \sum_{i=1}^n \log(e^{(-\beta_0 - \sum_j \beta_j x_{ij})(1-y_i)}) - \log(1 + e^{-\beta_0 - \sum_j \beta_j x_{ij}})\\ & = \sum_{i=1}^n (-\beta_0 - \sum_j \beta_j x_{ij})(1-y_i) - \log(1 + e^{-\beta_0 - \sum_j \beta_j x_{ij}}).\end{aligned}\] There is no closed form solution for the maximum of \(\ell(\boldsymbol{\beta})\), so the maximum likelihood estimator is typically found using numerical techniques such as Newton’s method.
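
To make this concrete, the following sketch codes the negative log-likelihood directly and maximises the likelihood numerically with optim() on some simulated data, then checks the answer against glm() (described below). The simulated data and starting values are arbitrary.

negloglik <- function(beta, X, y) {
  eta <- as.vector(X %*% beta)              # linear predictor for each observation
  sum((1 - y) * eta + log(1 + exp(-eta)))   # negative of the log-likelihood derived above
}
set.seed(1)
x1 <- rnorm(100)
y <- rbinom(100, 1, plogis(-1 + 2 * x1))    # simulate from a known logistic model
X <- cbind(1, x1)                           # design matrix, including an intercept column
fit.ml <- optim(c(0, 0), negloglik, X = X, y = y, method = "BFGS")
fit.ml$par                                  # numerical maximum likelihood estimates
coef(glm(y ~ x1, family = binomial))        # should agree closely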

Logistic regression models may be fit in R using the glm function. The basic syntax is

glm(formula, family, data)

where

  • formula is the model formula,

  • family describes the error distribution. Specify family=binomial for logistic regression59,

  • data is the data frame to use (optional).

As with linear regression, the fitted model should be assigned to a fitted model object for later use.

The formula is of the same form as for linear regression

y ~ x1 + x2

where y in this case is a binary variable (0/1) and x1 and x2 are predictor variables, either quantitative or factors.

Example 6.7 Logistic regression modelling for Swiss bank notes

We have measurements on 100 forged and 100 genuine Swiss bank notes. On each bank note the size of the bottom margin was recorded, as was the diagonal length of the note (all measurements in mm). We are interested in classifying the notes into forged and genuine classes based on the margin and diagonal measurements. The data has been split into training and validation sets, containing 120 and 80 notes respectively.

swiss.train <- read_csv("../data/swiss-train.csv") |>
  mutate(type = factor(type))
swiss.test <- read_csv("../data/swiss-test.csv") |>
  mutate(type = factor(type))
glimpse(swiss.train)
#> Rows: 120
#> Columns: 3
#> $ margin   <dbl> 9.7, 9.5, 10.0, 11.2, 10.7, 7.4, 11.6, 7.7, 9.3, 8.0, 8.2, 11…
#> $ diagonal <dbl> 140.6, 141.1, 139.4, 139.9, 139.5, 141.8, 140.2, 141.8, 138.3…
#> $ type     <fct> forged, genuine, forged, forged, forged, genuine, forged, gen…
swiss.glm.1 <- glm(type ~ margin, family=binomial, data=swiss.train)
summary(swiss.glm.1)
#> 
#> Call:
#> glm(formula = type ~ margin, family = binomial, data = swiss.train)
#> 
#> Coefficients:
#>             Estimate Std. Error z value Pr(>|z|)    
#> (Intercept)  19.9677     3.5250   5.665 1.47e-08 ***
#> margin       -2.1582     0.3842  -5.618 1.93e-08 ***
#> ---
#> Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
#> 
#> (Dispersion parameter for binomial family taken to be 1)
#> 
#>     Null deviance: 165.822  on 119  degrees of freedom
#> Residual deviance:  73.596  on 118  degrees of freedom
#> AIC: 77.596
#> 
#> Number of Fisher Scoring iterations: 6
swiss.glm.2 <- glm(type ~., family=binomial, data=swiss.train)
step(swiss.glm.2)
#> Start:  AIC=13.26
#> type ~ margin + diagonal
#> 
#>            Df Deviance    AIC
#> <none>           7.260 13.260
#> - margin    1   12.651 16.651
#> - diagonal  1   73.596 77.596
#> 
#> Call:  glm(formula = type ~ margin + diagonal, family = binomial, data = swiss.train)
#> 
#> Coefficients:
#> (Intercept)       margin     diagonal  
#>    -985.258       -4.920        7.324  
#> 
#> Degrees of Freedom: 119 Total (i.e. Null);  117 Residual
#> Null Deviance:       165.8 
#> Residual Deviance: 7.26  AIC: 13.26

Some comments on this R snippet.

  • The type variable has two levels, forged and genuine in that order. Thus, forged will be taken as the class represented by \(y=0\), and genuine will be the class represented by \(y=1\) in the logistic regression.

  • The glm line is fitting a logistic model with type, a factor with two possible values as the response. For swiss.glm.1 there is just a single predictor, margin. For swiss.glm.2 both margin and diagonal are predictors.

  • For swiss.glm.1 we see that the coefficient of margin is negative, so \({\textsf P}(y=1 | \texttt{margin})\) (i.e. probability of genuine note) decreases as margin increases.

  • For swiss.glm.1 the linear predictor is \(\theta(\texttt{margin}) = \beta_0 + \beta_1 \texttt{margin}\). The maximum likelihood estimates for the parameters are \(\hat \beta_0 = 19.9677\) and \(\hat \beta_1 = -2.1582\).

  • The Deviance mentioned in the output is the logistic regression analogue of sum of squares in linear regression.

  • As with linear regression, there is the question of which selection of predictor variables provides a model with the best bias-variance trade-off. Again, mirroring our discussion of linear regression for prediction, we can use backwards variable selection implemented to minimise AIC as a way of finding a reasonable set of predictor variables. In this particular case we see that the model swiss.glm.2 has a lower AIC than either of the models that can be obtained by removing a single predictor, so the backwards variable selection algorithm (implemented using step) terminates immediately.

6.5.2 Logistic Regression as a classifier

When using logistic regression as a classifier, for a given value of the predictors \(\boldsymbol{x}\) we would classify as a 1 if \({\textsf P}(Y=1|\boldsymbol{x}) > {\textsf P}(Y=0|\boldsymbol{x})\) and as 0 otherwise; i.e. we classify as a 1 if \[\frac{{\textsf P}(Y=1|\boldsymbol{x})}{{\textsf P}(Y=0|\boldsymbol{x})} > 1.\] Now, \({\textsf P}(Y=1|\boldsymbol{x}) = p(\boldsymbol{x})\), and thus \({\textsf P}(Y=0|\boldsymbol{x}) = 1-p(\boldsymbol{x})\). Hence, we classify as a 1 if \[\begin{aligned} & & \frac{P(Y=1|\boldsymbol{x})}{P(Y=0|\boldsymbol{x})} &> 1,\\ & \iff & \frac{p(\boldsymbol{x})}{1-p(\boldsymbol{x})} &> 1,\\ & \iff & \log\left(\frac{p(\boldsymbol{x})}{1-p(\boldsymbol{x})}\right) &> 0,\\ & \iff & {\textsf {logit}}(p(\boldsymbol{x})) &> 0,\\ & \iff & \beta_0 + \beta_1x_1 + \cdots + \beta_px_p &> 0.\end{aligned}\] This gives a simple formula for classifying: evaluate the linear predictor and assign to class 1 if positive and class 0 if negative.

Example 6.8 Classifying the Swiss Bank Notes data

Recall from the previous example that the model swiss.glm.1 had \(\theta(\texttt{margin}) = \beta_0 + \beta_1 \texttt{margin}\) with maximum likelihood estimates \(\hat \beta_0 = 19.9677\) and \(\hat \beta_1 = -2.1582\). It follows that the fitted model (based on training data) has linear predictor \(\hat \theta(\texttt{margin}) = 19.9677 -2.1582~\texttt{margin}\). We would hence classify a test case to genuine notes (coded by \(y=1\)) if and only if \[\begin{aligned} 19.9677 -2.1582~\texttt{margin} \ge 0 &\Leftrightarrow& \texttt{margin} \le \frac{19.9677}{2.1582} \\ &\Leftrightarrow& \texttt{margin} \le 9.252~~~.\end{aligned}\] A quick exploration of the data should persuade you that this is a natural result: most of the forged notes are detectable by their surprisingly large bottom margins.

The predict command on a fitted model object from glm by default returns predictions from the linear predictor, rather than on the original scale of the data. This is a bit fiddly - we could assign the positive predictions to the second class (class 1), and the negative predictions to the first class (class 0).
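
A minimal sketch of that manual approach for swiss.glm.2, thresholding the linear predictor at zero (recall that genuine is the second level of type):

eta <- predict(swiss.glm.2, newdata=swiss.test)          # linear predictor (the default scale)
pred.class <- ifelse(eta > 0, "genuine", "forged")
table(Prediction = pred.class, Truth = swiss.test$type)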

Alternatively, we can just use the tidymodels framework with the logistic_reg() function in place of glm().

Example 6.9 More on Classifying the Swiss Bank Notes data

We continue our example on the Swiss bank notes by using the output from glm() to compute the class probabilities ourselves:

library(broom)
swiss.glm.2 |> augment(newdata=swiss.test) |>
  mutate(.pred_genuine = 1/(1 + exp(-.fitted)))
#> # A tibble: 80 × 5
#>    margin diagonal type     .fitted .pred_genuine
#>     <dbl>    <dbl> <fct>      <dbl>         <dbl>
#>  1    8.7    137.8 forged  -18.767      7.0707e-9
#>  2    8.9    142.3 genuine  13.208      1.0000e+0
#>  3   10.3    139.7 forged  -12.723      2.9821e-6
#>  4   10.2    139.6 forged  -12.963      2.3447e-6
#>  5   10.1    139.2 forged  -15.401      2.0483e-7
#>  6   10.6    139.7 forged  -14.199      6.8157e-7
#>  7    9.1    141.5 genuine   6.3648     9.9828e-1
#>  8   10.4    140.3 forged   -8.8203     1.4769e-4
#>  9    7.9    140.9 genuine   7.8741     9.9962e-1
#> 10    8      139.6 genuine  -2.1396     1.0531e-1
#> # ℹ 70 more rows

The .fitted entry contains the predicted logits for the test cases. Hence for test case 1 (the predictors for which we denote by the vector \(\boldsymbol{x}_0\)), \[\theta(\boldsymbol{x}_0) = \log \left ( \frac{p(\boldsymbol{x}_0)}{1 - p(\boldsymbol{x}_0)} \right ) = -18.77 .\] It follows that \[{\textsf P}(y=1 | \boldsymbol{x}_0) = 1/(1+e^{-\theta(\boldsymbol{x}_0)}) = 1/(1+e^{18.77}) \approx 0.\]

Thus, each test case will be assigned to the genuine class (coded \(y=1\)) if and only if it has a posterior probability \(p(\boldsymbol{x}) > 0.5\), or equivalently if and only if its linear predictor \(\theta(\boldsymbol{x}_0) > 0\).

Of course, the tidymodels framework takes care of this for us, so if we instead utilise our existing framework knowledge, we can get the same predictions as we’re used to:

swiss.lr.2 <- logistic_reg() |>
  fit(type ~ ., data=swiss.train)
swiss.lr.2.pred <- swiss.lr.2 |>
  augment(new_data=swiss.test)
swiss.lr.2.pred
#> # A tibble: 80 × 6
#>    .pred_class .pred_forged .pred_genuine margin diagonal type   
#>    <fct>              <dbl>         <dbl>  <dbl>    <dbl> <fct>  
#>  1 forged      1.0000           7.0707e-9    8.7    137.8 forged 
#>  2 genuine     0.0000018354     1.0000e+0    8.9    142.3 genuine
#>  3 forged      1.0000           2.9821e-6   10.3    139.7 forged 
#>  4 forged      1.0000           2.3447e-6   10.2    139.6 forged 
#>  5 forged      1.0000           2.0483e-7   10.1    139.2 forged 
#>  6 forged      1.0000           6.8157e-7   10.6    139.7 forged 
#>  7 genuine     0.0017181        9.9828e-1    9.1    141.5 genuine
#>  8 forged      0.99985          1.4769e-4   10.4    140.3 forged 
#>  9 genuine     0.00038034       9.9962e-1    7.9    140.9 genuine
#> 10 forged      0.89469          1.0531e-1    8      139.6 genuine
#> # ℹ 70 more rows

As always, augment() augments the test data swiss.test with the predicted class .pred_class as well as the predicted probabilities of each class (.pred_forged and .pred_genuine). We can thus directly cross tabulate in the usual way:

swiss.lr.2.pred |>
  conf_mat(truth=type, estimate=.pred_class)
#>           Truth
#> Prediction forged genuine
#>    forged      36       1
#>    genuine      0      43

Some comments:

  • The confusion matrix tabulates the predicted class versus the true class from the swiss.test set. This shows that all but one of the bank notes have been correctly classified, a misclassification rate of just \(1/80\).

Logistic regression, particularly when being applied to data sets that can be relatively easily classified, can end up creating a perfect classifier based on the training data. In this case there is no unique solution for the maximum likelihood estimate of the coefficients – there are an infinite number of possible coefficients that will produce the same perfect predictions. We illustrate this with a slight change to the Swiss bank notes data.

Example 6.10 Classification of modified Swiss Bank Notes

Consider the following R snippet, where we have changed a single observation in the swiss.train data set slightly. The plot is given in Figure 6.6.

swiss2.train <- swiss.train |>
  mutate(diagonal = if_else(row_number() == 103, 140.2, diagonal))

ggplot(data=swiss2.train) +
  geom_point(mapping = aes(x=margin, y=diagonal, col=type)) +
  scale_colour_manual(guide = 'none', values=c('red', 'black'))

Figure 6.6: Forged (black) and genuine (red) Swiss bank notes from Example 6.10

swiss2.lr <- logistic_reg() |> fit(type ~ ., data=swiss2.train)
#> Warning: glm.fit: algorithm did not converge
#> Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred

Some comments on the R code.

  • glm.fit gives two warning messages, firstly that the algorithm did not converge, and secondly that fitted probabilities numerically indistinguishable from 0 or 1 occurred.

  • Figure 6.6 suggests there is a line that perfectly separates the data into the two groups. It is this perfect separation based on a linear estimator that causes glm to not find a unique maximum likelihood estimate, and thus fail to converge. Further, the predicted probabilities on the training set are indistinguishable from 0 and 1.

  • In this particular case, the coefficients found from the model lead to a perfect fit of the training data. However, the algorithm used to find the maximum likelihood estimators may fail to converge for other reasons, such as collinearity of the predictors, or a lack of data (no observations) for certain combinations of factors.

6.5.3 Extending Logistic Regression to more than two classes

Logistic regression extends naturally to more than two categories by moving from a single probability of success \(p\) to the probability of each class \(p_c\) for \(c \in 1, \ldots, C\) where \(C\) is the number of classes. We then model \(p_c\) as \[p_c = \left\{ \begin{array}{ll}\frac{\displaystyle 1}{\displaystyle 1 + \sum_{m=2}^C e^{-\beta_{m0} - \sum_{j=1}^k \beta_{mj}x_j}} & \textrm{if } c=1\\ \\ \frac{\displaystyle e^{-\beta_{c0} - \sum_{j=1}^k \beta_{cj}x_j}}{\displaystyle 1 + \sum_{m=2}^C e^{-\beta_{m0} - \sum_{j=1}^k \beta_{mj}x_j}} & \textrm{if } c\in \{2, \ldots, C\}. \end{array}\right.\] Notice that we now have \(C-1\) lots of \(k\) coefficients to estimate, one set for each of the classes excluding one. The last class need not have separate estimators, as the probabilities \(p_c\) must sum to one60.
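
A small sketch of this calculation, following the parameterisation above: given a matrix eta with one row per observation and one column per class \(2, \ldots, C\) holding the linear predictors \(\beta_{c0} + \sum_j \beta_{cj} x_j\), the class probabilities can be computed as follows (the example matrix is arbitrary).

class_probs <- function(eta) {
  num <- cbind(1, exp(-eta))   # class 1 contributes the 1; classes 2,...,C contribute exp(-eta)
  num / rowSums(num)           # each row of probabilities sums to one
}
class_probs(matrix(c(0.5, -1, 2, 0), nrow=2))   # two observations, three classes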

Such models are typically referred to as multinomial regression models, as they extend the binomial distribution to trials with more than 2 outcomes.

The multinom_reg function in parsnip fits this by default using a neural network via the nnet library. We illustrate this with an analysis of the Italian wine data.

Example 6.11 Multinomial Regression for the Wine Data

We perform classification using the multinom_reg function.

wine.mn <- multinom_reg() |>
  fit(Cultivar ~ ., data=wine.train)
wine.mn |>
  extract_fit_engine()
#> Call:
#> nnet::multinom(formula = Cultivar ~ ., data = data, trace = FALSE)
#> 
#> Coefficients:
#>   (Intercept)   Alcohol      Malic       Ash  Alc.Ash         Mg  Phenols
#> B  279.836344  2.510949 -14.231258 -95.03628 2.485551 0.14497172 53.54385
#> C   -4.627004 23.977218  -9.053405 -26.39729 2.855262 0.09578787 58.79897
#>        Flav  Non.Flav    Proan    Colour       Hue  ODRatio    Proline
#> B -92.06688 106.00528 60.79086 -16.23188 218.29240 20.26225 -0.5097276
#> C -99.62994 -47.82549 35.94083  15.05347 -31.03135 21.60742 -0.4933040
#> 
#> Residual Deviance: 4.412777e-06 
#> AIC: 56
wine.mn.pred <- wine.mn |> augment(new_data=wine.test)
wine.mn.pred |>
  conf_mat(truth=Cultivar, estimate=.pred_class)
#>           Truth
#> Prediction  A  B  C
#>          A 32  3  1
#>          B  0 28  0
#>          C  0  0 14

Some comments:

  • The summary for multinom objects gives the values of the coefficients for classes B and C but not for class A, as those are then determined by the constraint that the probabilities sum to 1.

  • A misclassification rate of \(4/78 \approx 5.1\%\) is similar to that achieved earlier.

We continue by attempting to fit a model containing just the variables Alcohol, Colour, and Hue.

wine.mn2 <- multinom_reg() |>
  fit(Cultivar ~ Alcohol + Colour + Hue, data=wine.train)
wine.mn2.pred <- wine.mn2 |> augment(new_data=wine.test)
wine.mn2.pred |>
  conf_mat(truth=Cultivar, estimate=.pred_class)
#>           Truth
#> Prediction  A  B  C
#>          A 30  3  1
#>          B  2 23  0
#>          C  0  5 14

We get a worse misclassification rate in this case of \(11/78 \approx 14.1\%\). Nonetheless, this used significantly fewer variables in the model. Might we improve on the full model by omitting some of the variables? We now consider a variable selection algorithm.

6.5.4 Model selection

As with linear regression, model selection is necessary for logistic regression in order to eliminate variables that are only weakly informative (or not informative at all) for classification. Eliminating such a variable will increase the bias of the model but reduce the variance, as we need not estimate a coefficient for that variable. This bias-variance trade-off requires a balance between model fit and model complexity.

The Akaike Information Criterion (AIC) may be used in logistic regression in the same way as it was used in linear regression. Indeed, we can utilise the step function to perform backwards elimination from a full model.
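
For the multinomial fits above, the AIC is the residual deviance plus twice the number of estimated coefficients: \[\textrm{AIC} = \textrm{residual deviance} + 2 \times (\textrm{number of estimated coefficients}).\] As a check, the full model fitted in Example 6.11 has \(2 \times 14 = 28\) estimated coefficients (an intercept plus 13 slopes for each of classes B and C) and a residual deviance that is essentially zero, so \(\textrm{AIC} \approx 2 \times 28 = 56\), exactly as reported.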

Unfortunately, doing so means moving away from the tidymodels framework, as step() does not know what to do with the wrapped glm or multinom model. The simplest approach is to fit the model directly with nnet::multinom(), apply step() to that fit, and then refit the selected model within the tidymodels framework to benefit from its improved manipulation tools.

Example 6.12 Multinomial model selection for the Wine Data

We use the step function on the full model wine.mn to produce a model with reduced AIC.

wine.mn.refit <- nnet::multinom(Cultivar ~ ., data=wine.train, trace=FALSE)
wine.mn3.refit <- step(wine.mn.refit)
summary(wine.mn3.refit)
#> Call:
#> nnet::multinom(formula = Cultivar ~ Alcohol + Flav + Hue + Proline, 
#>     data = wine.train, trace = FALSE)
#> 
#> Coefficients:
#>   (Intercept)   Alcohol      Flav       Hue     Proline
#> B   245.00191 -17.44058 -12.95288  71.47525 -0.08345303
#> C   -11.09362  15.55627 -42.50984 -73.13526 -0.09439152
#> 
#> Std. Errors:
#>   (Intercept)     Alcohol         Flav          Hue    Proline
#> B 0.264644803 3.238274047 14.715503080 0.2276467814 0.01392838
#> C 0.000605084 0.008377486  0.005430621 0.0003053845 0.19769171
#> 
#> Residual Deviance: 0.3720983 
#> AIC: 20.3721
wine.mn3 <- multinom_reg() |>
  fit(Cultivar ~ Alcohol + Flav + Hue + Proline, data=wine.train)
wine.mn3.pred <- wine.mn3 |> augment(new_data=wine.test)
wine.mn3.pred |>
  conf_mat(truth=Cultivar, estimate=.pred_class)
#>           Truth
#> Prediction  A  B  C
#>          A 32  2  0
#>          B  0 27  0
#>          C  0  2 15

Some comments on this R snippet.

  • The step command can be used in the same way as for linear regression, but only if not using the tidymodels framework. Note that by writing nnet::multinom() we don’t have to explicitly load the nnet library with library() first.

  • The final model contained Alcohol, Flav, Hue and Proline as predictors. We then re-fit this model in the tidymodels framework so we can simply use augment() and conf_mat() as always.

  • The final model has a misclassification rate of \(4/78 \approx 5.1\%\), the same as the full model, even though it contains only four predictors.
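
  • As an arithmetic check on the reported AIC (residual deviance plus twice the number of estimated coefficients): the reduced model has \(2 \times 5 = 10\) coefficients (an intercept plus four slopes for each of classes B and C), so \(\textrm{AIC} \approx 0.372 + 2 \times 10 = 20.372\), matching the output above.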

6.6 Other Methods of Classification

The classifiers that we have considered to date have either been based on Equation (6.1) using various explicit models for the probability functions \(\{f_j\}\) or, in the case of logistic regression, based on a functional form of the probabilities \(P(j | \boldsymbol{x})\). We now turn to some classification techniques that are more model-free in nature, namely k-nearest neighbour classifiers, classification trees and neural networks. The first of these is designed to work with quantitative predictors only, but the latter pair can also incorporate factors.

6.6.1 k-Nearest Neighbour Classification

The idea behind k-nearest neighbour classification is to find the \(k\) data points in the training set that are most similar to the test observation at hand, and then assign the test case to the most common class among these proximate training data. This general idea should ring some bells: we met nearest neighbour imputation when handling missing values.

Two important issues to consider are the scaling of the data, and the choice of \(k\). With regard to the former, note that k-nearest neighbour classification is (usually) based on Euclidean distances between observations. This means that if there is disparity in the scales of the variables, those taking larger values will tend to dominate. Furthermore, it means that the methodology will not be scale invariant. For example, changing the unit of measurement for a variable from metres to centimetres will change the resulting classifier, which is undesirable. If the raw variables are on comparable, natural scales then this may not be a significant problem, but otherwise it will usually be wise to transform the predictors so that they have common means and standard deviations. Fortunately, the default engine for nearest neighbours in the tidymodels framework automatically scales the predictors to have common standard deviations for us.
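
As a minimal illustration of why scaling matters, consider two of the wine predictors whose scales differ greatly (Proline is in the hundreds, Non.Flav is below one); raw Euclidean distances are driven almost entirely by Proline:

# Distances between the first three training wines, raw versus standardised
raw <- wine.train[1:3, c("Proline", "Non.Flav")]
dist(raw)          # dominated by Proline, whose values are in the hundreds
dist(scale(raw))   # after centring and scaling, both variables contribute comparably
                   # (scaling here uses just these three rows, purely for illustration)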

Turning to the choice of \(k\), the usual considerations of the bias-variance trade-off apply. A small value of \(k\) (e.g. \(k=1\)) will result in a method with considerable variability. A large value of \(k\) will reduce the variability (since the choice of class will now be determined by majority vote over a larger number of observations), but at the expense of introducing bias (since we will be comparing observations with greater disparity in the pattern of predictor values).
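
In practice \(k\) is often chosen by cross-validation. The following is a sketch of how this might be done with the tidymodels tuning tools (the packages loaded by library(tidymodels)); the grid of candidate values here is arbitrary:

library(tidymodels)    # provides vfold_cv(), workflow(), tune_grid(), show_best(), ...

set.seed(1234)
wine.folds <- vfold_cv(wine.train, v = 5)                    # 5-fold cross-validation
knn.tune <- nearest_neighbor(mode = "classification",
                             neighbors = tune())             # k left to be tuned
knn.wf <- workflow() |>
  add_model(knn.tune) |>
  add_formula(Cultivar ~ .)
knn.res <- tune_grid(knn.wf, resamples = wine.folds,
                     grid = tibble(neighbors = c(1, 5, 10, 20, 30)))
show_best(knn.res, metric = "accuracy")                      # best values of k by CV accuracy

The chosen \(k\) can then be fixed in nearest_neighbor() and the model refitted on the full training set.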

The method of k-nearest neighbour classification is implemented in R by several packages. We’ll be utilising the kknn package for this via the parsnip wrapper nearest_neighbor().

Example 6.13 k-Nearest Neighbour Classification for the Wine Data

We perform classification using the k-nearest neighbour method. Note that this automatically scales our predictor variables to have a common variance. We try \(k=1\) and \(k=20\) to compare.

wine.knn1 <- nearest_neighbor(mode="classification", neighbors = 1) |>
  fit(Cultivar ~ ., data=wine.train)
wine.knn1.pred <- wine.knn1 |>
  augment(new_data = wine.test)
wine.knn1.pred |>
  conf_mat(truth=Cultivar, estimate=.pred_class)
#>           Truth
#> Prediction  A  B  C
#>          A 32  5  0
#>          B  0 22  0
#>          C  0  4 15
wine.knn2 <- nearest_neighbor(mode="classification", neighbors = 20) |>
  fit(Cultivar ~ ., data=wine.train)
wine.knn2.pred <- wine.knn2 |>
  augment(new_data = wine.test)
wine.knn2.pred |>
  conf_mat(truth=Cultivar, estimate=.pred_class)
#>           Truth
#> Prediction  A  B  C
#>          A 32  1  0
#>          B  0 27  0
#>          C  0  3 15

Using \(k=1\) results in a misclassification rate of \(9/78 \approx 11.5\%\), while using \(k=20\) results in a misclassification rate of \(4/78 \approx 5.1\%\).

6.6.2 Classification Trees

The regression trees introduced in Section 4.4 are easily adapted to the problem of classification. We will continue to use binary trees, but with the sequence of splits now chosen so as to most improve classification (as opposed to prediction) performance. This means that we need to modify the sum of squares error criterion that was used to assess splits with regression trees.

A common option (and the default with the R function rpart) is to measure node impurity using the Gini splitting index. The Gini index measures how often a randomly selected observation would be misclassified if the class labels of all observations at the node were randomly reallocated, and so is given by \[\begin{equation} G = \sum_{j=1}^C \hat p_j (1 - \hat p_j) = 1 - \sum_{j=1}^C \hat p_j^2 \tag{6.4} \end{equation}\] where \(\hat p_j\) is the proportion of observations at the node in question which are of class \(j\). Note that if the node is entirely pure, in the sense that only observations from one class are present, then \(G = 0\). On the other hand, if the observations at the node are equally split between all classes, then \(G = 1 - 1/C\) (the maximum value it can take). We will seek splits that minimize the Gini index and hence minimize impurity.

An alternative (which can also be implemented using rpart) is to instead aim to minimize the cross-entropy (or information) impurity \[\begin{equation} I_E = -\sum_{j=1}^C \hat p_j \log(\hat p_j). \tag{6.5} \end{equation}\] Like the Gini index, this is zero for a pure node and maximal when the classes are equally represented at the node.
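
The following is a small sketch (independent of rpart) of the two impurity measures, evaluated for a pure node and for a node with an equal three-way split:

impurity <- function(classes) {
  p <- prop.table(table(classes))       # class proportions at the node
  c(gini    = 1 - sum(p^2),             # Equation (6.4)
    entropy = -sum(p * log(p)))         # Equation (6.5)
}
impurity(c("A", "A", "A", "A"))         # pure node: both measures are 0
impurity(c("A", "B", "C"))              # equal split: Gini = 1 - 1/3, entropy = log(3)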

Example 6.14 Growing Classification Trees for the Wine Data

We grow a classification tree for the wine data using the default settings.

wine.rp.1 <- decision_tree(mode="classification") |>
  fit(Cultivar ~ . ,data=wine.train)
wine.rp.1.pred <- wine.rp.1 |>
  augment(new_data=wine.test)
wine.rp.1.pred |>
  conf_mat(truth=Cultivar, estimate=.pred_class)
#>           Truth
#> Prediction  A  B  C
#>          A 29  5  0
#>          B  3 24  1
#>          C  0  2 14

This R snippet should be largely self-explanatory, as it follows the same pattern that should now be familiar. The overall misclassification rate is \(11/78 \approx 14.1\%\).

Classification tree for the wine data.

Figure 6.7: Classification tree for the wine data.

The tree employed is displayed in Figure 6.7. It is very simple, being built on just the variables Colour and Flav. Nonetheless, it is difficult to improve on this: for example, reducing the complexity parameter to cost_complexity=0.0001 (output not shown) results in no change to the model. A sketch of how this might be done is given below.
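
For completeness, here is how the lower-complexity fit might be obtained, together with one way of drawing the fitted tree (the rpart.plot package is one option; it is not necessarily how Figure 6.7 was produced):

wine.rp.2 <- decision_tree(mode = "classification", cost_complexity = 0.0001) |>
  fit(Cultivar ~ ., data = wine.train)
wine.rp.2 |>
  extract_fit_engine() |>         # the underlying rpart object
  rpart.plot::rpart.plot()        # same splits on Colour and Flav as before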

6.6.3 Random Forests for Classification

Random Forests may be adapted to the task of classification in the same way as decision trees are adapted: we change the trees to classification trees rather than regression trees, and assemble an ensemble of such trees. Prediction is then done by each tree producing its individual predicted class for each observation, and then voting across all the trees, choosing the most frequently predicted class. Note that in some cases the vote may result in an even split (e.g. 200 trees might say class “A”, 200 trees class “B”, and 100 trees class “C”), in which case the assignment may be implementation-specific, but it is typically decided through random allocation.
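
As a toy illustration of the voting step (with hypothetical vote counts for a single test observation from a 500-tree forest):

votes <- c(A = 310, B = 150, C = 40)
names(which.max(votes))                           # majority vote: class A
# With an even split, one common approach is random allocation among the tied classes:
tied <- c(A = 200, B = 200, C = 100)
sample(names(tied)[tied == max(tied)], size = 1)  # picks A or B at random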

Example 6.15 A Classification Forest for the Wine Data

Our classification forest is fit in the same way as we did for prediction, switching out the mode to be “classification”. We’ll utilise the randomForest package as before.

wine.rf <- rand_forest(mode="classification", engine="randomForest") |>
  fit(Cultivar ~ . ,data=wine.train)
wine.rf.pred <- wine.rf |>
  augment(new_data=wine.test)
wine.rf.pred |>
  conf_mat(truth=Cultivar, estimate=.pred_class)
#>           Truth
#> Prediction  A  B  C
#>          A 32  1  0
#>          B  0 27  0
#>          C  0  3 15

The misclassification error is \(4/78 \approx 5.1\%\), similar to many of the other models. These data are not hard to classify!

6.6.4 Neural Networks for Classification

Neural networks for classification work much like neural networks for prediction, although the model fitting criterion is changed from the least squares criterion used for prediction. The R implementation nnet handles situations with two classes a little differently to cases with three or more classes. In the former case there is need for just one output node, and model fitting is done with respect to the entropy criterion of Equation (6.5). When the target variable has three or more classes, the number of output nodes equals the number of classes, and fitting is done using conditional maximum likelihood estimation61. In either case, the output is a classification probability (and so constrained to the interval \([0,1]\)) given by \[\hat p = \psi(\beta_0 + \beta_1 z_1 + \cdots + \beta_M z_M)\] where \(z_1,\ldots,z_M\) are the derived features of the hidden nodes, and \(\psi\) is a sigmoid transformation of the form that we saw previously in Figure 4.14. With the R function nnet this means that the argument linout must be set to FALSE (which is the default, so there is no need to do so explicitly).
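
In the multi-class case the \(C\) output nodes are combined through a softmax stage, so that (in the notation above) the fitted classification probabilities take the form \[\hat p_c = \frac{e^{t_c}}{\sum_{m=1}^C e^{t_m}}, \qquad \textrm{where } t_c = \beta_{c0} + \beta_{c1} z_1 + \cdots + \beta_{cM} z_M,\] which guarantees that the \(\hat p_c\) are non-negative and sum to one.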

Example 6.16 Neural Network Classification for the Wine Data

We will apply a neural network for classifying the Italian wine data. As was indicated in the chapter on prediction, the nnet function does not work well if there is significant variation in the scales of the predictor variables. This is certainly the case with the wine data; for example, in the training set the variable Proline ranges from 290 to 1680, while Non.Flav ranges from 0.14 to 0.66. We therefore use a recipe to scale the data first.

library(recipes)
scaling <- recipe(Cultivar ~ ., data=wine.train) |>
  step_normalize(all_numeric_predictors()) |>
  prep(wine.train)

set.seed(1000)
wine.nn.1 <- mlp(mode="classification", hidden_units = 7) |>
  fit(Cultivar ~ ., data = bake(scaling, wine.train))
wine.nn.1 |> extract_fit_engine() |> pluck("convergence")
#> [1] 0
wine.nn.1.pred <- wine.nn.1 |>
  augment(new_data = bake(scaling, wine.test))
wine.nn.1.pred |>
  conf_mat(truth=Cultivar, estimate=.pred_class)
#>           Truth
#> Prediction  A  B  C
#>          A 32  0  0
#>          B  0 28  0
#>          C  0  3 15

Some notes:

  • Our scaling recipe is set up in the usual way, adding a step to normalise all the numeric predictors, and then preparing the recipe using the training data. We then bake() the data prior to passing it to fit() or augment()/predict().

  • The line with pluck("convergence") is checking whether the neural network has converged: The fitted neural network object (extracted via extract_fit_engine()) is a list (as are many R model objects) with a ‘convergence’ entry - you can see this from the help for nnet(). The pluck() function is for pulling information out of (possibly nested) lists. In this case we could have also used:

    wine.nn.1 |> pluck("fit", "convergence")

    as the wine.nn.1 object returned from parsnip is itself a list, with the engine fit object in the entry fit. Thus, the pluck here would traverse the first list to find the fit entry, and then traverse within that list to find convergence.

  • The overall model fit here is pretty good, with misclassification rate \(3/78 \approx 3.8\%\), although it should be remembered that results from nnet may vary a little depending on the (random) choice of initial conditions. It follows that running the above code with a different setting of set.seed may give a slightly different error rate.

6.7 High Class Classifiers

We have discussed a variety of techniques for classification. As usual, the ones that require the greatest number of assumptions regarding the data will typically perform best when those assumptions hold, but can do very poorly when they do not. Given that in practice we will not know for certain what assumptions are valid, many data miners prefer to use more flexible tools like classification trees and neural networks.

In these notes we have provided an outline of how each method works. Nonetheless, there is considerable scope for tinkering with the details! In particular, much time and effort continues to be expended in attempting to improve classification accuracy by fractions of one percent. There are frequently competitions on the Internet where teams of researchers seek to obtain the best classifications possible. At present, Kaggle (www.kaggle.com) hosts arguably the most well known website for such competitions.


  1. As the names of the first two listed techniques suggest, statistical classification can also be referred to as discriminant analysis. As we saw earlier, supervised learning is another name for classification, favoured in computer science and related areas.↩︎

  2. We use the somewhat ambiguous term probability function to refer to the probability mass function for discrete random variables, or the probability density function for continuous random variables, as appropriate.↩︎

  3. Bayes’ Theorem is a result of huge importance in Statistics. The result is credited to English mathematician and Presbyterian minister, Reverend Thomas Bayes (1701?-1761).↩︎

  4. If you would like to learn more about the multivariate normal distribution, and the theory of statistics in general (which is both wonderfully elegant and useful) then you should consider taking the paper 161.304.↩︎

  5. Note that this package is not on CRAN but can be installed from Github using the remotes package. See https://github.com/jmarshallnz/tidykda for details.↩︎

  6. We need to estimate only two of the probabilities because the third is then determined by the constraint that the probabilities in a distribution must sum to one.↩︎

  7. Beware of a possible notational confusion between \(p=p(\boldsymbol{x})\), the probability of a response \(y=1\) given predictors \(\boldsymbol{x}=(x_1, \ldots, x_p)^\mathsf{T}\), and \(p\) the number of predictors. The distinction will usually be obvious by the context. Moreover, we will usually write \(p(\boldsymbol{x})\) for the probability, emphasising its dependence on the predictor variables.↩︎

  8. Recall the sigmoid function was also used for the activation function in neural networks.↩︎

  9. If the mathematics that follows below is beyond you, do not worry: you won’t be examined on this. We include it for the mathematically interested. Nonetheless, be aware that if you want to learn more about statistics and data mining, then it is important to learn enough mathematics so that you can cope with this type of material.↩︎

  10. The maximum likelihood estimate in linear regression is given by the least squares solution, as maximizing the (normal) log-likelihood is equivalent to minimizing the sum of squared residuals.↩︎

  11. The glm function may be used for more than just logistic regression. It can be used in place of lm for linear regression via family=gaussian, or used for Poisson regression via family=poisson.↩︎

  12. For the same reason, we required only one probability \(p\) for the binary case, as the other probability was necessarily \(1-p\).↩︎

  13. We do not cover the details of conditional maximum likelihood here. An introduction to the method of maximum likelihood estimation is provided in the paper 161.304 Statistical Modelling.↩︎