The Limitation of Accuracy of Classification Models

Author

zenggyu

Published

2018-06-07

Abstract
Explains why the accuracy of the predictions made by a classification model can be low, even if the model estimates the probability of an outcome perfectly.

It didn’t occur to me that even if a classification model can perfectly estimate the probability of an outcome, the accuracy of the prediction can still be low. This post explains the phenomenon.

Consider a model that estimates the probability of a (binary) outcome from the values of its predictors. Suppose that, for any observation, the predicted outcome is drawn at random according to the estimated probability. Then the accuracy of the prediction can be expressed as a function of the true probability $p_1$ and the estimated probability $p_2$:

$$A(p_1, p_2) = p_1 p_2 + (1 - p_1)(1 - p_2) = (2p_1 - 1)p_2 + 1 - p_1 \tag{1}$$
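As a sanity check, equation (1) can be compared with a simulation. The sketch below is my own illustration, assuming binary outcomes and predictions drawn independently of the true outcome; the helper names `simulate_accuracy()` and `analytic_accuracy()` are hypothetical.

```r
# Monte Carlo check of equation (1).
# Assumption: outcomes are binary and each prediction is drawn
# independently of the true outcome with P(prediction = 1) = p2.
set.seed(42)

simulate_accuracy <- function(p1, p2, n = 1e5) {
  truth      <- rbinom(n, size = 1, prob = p1)  # true outcomes, P(Y = 1) = p1
  prediction <- rbinom(n, size = 1, prob = p2)  # random predictions, P(Yhat = 1) = p2
  mean(truth == prediction)                     # share of correct predictions
}

analytic_accuracy <- function(p1, p2) {
  p1 * p2 + (1 - p1) * (1 - p2)                 # equation (1)
}

simulate_accuracy(0.7, 0.7)  # close to 0.58
analytic_accuracy(0.7, 0.7)  # [1] 0.58
```

Even with a perfect estimate ($p_2 = p_1 = 0.7$), only about 58% of the predictions are correct.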

The partial derivative of $A$ with respect to $p_2$ is:

$$\frac{\partial A}{\partial p_2} = 2p_1 - 1 \tag{2}$$

According to the derivative, as $p_2$ increases, $A(p_1, p_2)$ increases monotonically if $p_1 > 0.5$, decreases monotonically if $p_1 < 0.5$, and remains constant if $p_1 = 0.5$. Therefore, given any $p_1 > 0.5$, $A(p_1, p_2)$ is maximized only when $p_2 = 1$, while given any $p_1 < 0.5$, it is maximized only when $p_2 = 0$. In either case, the maximum is not reached at $p_2 = p_1$ (unless $p_1 = p_2 = 0$ or $1$).
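The same conclusion can be checked numerically. This hypothetical sketch evaluates equation (1) on a grid of $p_2$ values for a few fixed values of $p_1$ and reports where the accuracy peaks; the function name `accuracy()` and the chosen values are illustrative assumptions.

```r
# Grid search over the estimated probability p2 for fixed true probabilities p1,
# using the accuracy function from equation (1).
accuracy <- function(p1, p2) p1 * p2 + (1 - p1) * (1 - p2)

p2_grid <- seq(0, 1, by = 0.01)

best_p2 <- sapply(c(0.2, 0.7, 0.9), function(p1) {
  # which.max() returns the first grid point attaining the maximum
  p2_grid[which.max(accuracy(p1, p2_grid))]
})
best_p2  # [1] 0 1 1
```

The case $p_1 = 0.5$ is left out on purpose: there every $p_2$ gives the same accuracy, so `which.max()` would just return the first grid point.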

Figure 1 visualizes the relationship among $A(p_1, p_2)$, $p_1$ and $p_2$.

Figure 1: Probability and accuracy

Formula (3) shows how the maximum accuracy varies with $p_1$. Except when $p_1 = 0$ or $1$, the maximum accuracy is below 1, and the closer $p_1$ is to 0.5, the lower it is. This illustrates an inherent limitation of classification models: when the outcome itself is uncertain, even a perfect estimate of its probability cannot guarantee accurate predictions.

$$A_{\max} = \begin{cases} p_1 & p_1 > 0.5 \\ 0.5 & p_1 = 0.5 \\ 1 - p_1 & p_1 < 0.5 \end{cases} \tag{3}$$
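To make the gap concrete, the sketch below compares the maximum accuracy in formula (3), which equals $\max(p_1, 1 - p_1)$, with the accuracy obtained when the model's estimate is perfect ($p_2 = p_1$), which by equation (1) is $p_1^2 + (1 - p_1)^2$. The function names `max_accuracy()` and `matched_accuracy()` are hypothetical.

```r
# Maximum attainable accuracy, formula (3): always predict the more likely class.
max_accuracy <- function(p1) pmax(p1, 1 - p1)

# Accuracy of random predictions drawn at the perfectly estimated rate p2 = p1,
# i.e. equation (1) evaluated at p2 = p1.
matched_accuracy <- function(p1) p1^2 + (1 - p1)^2

p1 <- c(0.1, 0.5, 0.7, 0.9)
data.frame(p1, max = max_accuracy(p1), matched = matched_accuracy(p1))
```

For example, with $p_1 = 0.7$ the perfectly estimated random predictions are correct 58% of the time, whereas always predicting the more likely class reaches 70%; both fall to 50% at $p_1 = 0.5$.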


The code that generates Figure 1 is shown below, with some useful notes in the comments:

library(tidyverse)  # crossing() (tidyr), mutate() (dplyr), walk() (purrr) and %>%
library(lattice)    # wireframe()

# Evaluate equation (1) on a grid of true (p1) and estimated (p2) probabilities.
dat <- crossing(p1 = seq(0, 1, 0.025),
                p2 = seq(0, 1, 0.025)) %>%
  mutate(accuracy = p1 * p2 + (1 - p1) * (1 - p2))

# Render one PNG frame per rotation angle z (15, 30, ..., 75).
# Zero-padded file names (015.png, 030.png, ...) keep the frames in order for `convert`.
walk(seq(15, 75, 15), function(z) {
  png(sprintf("%03i.png", z))
  wireframe(accuracy ~ p1 + p2, data = dat,
            # Remove the border, but keep the axis ticks.
            par.settings = list(axis.line = list(col = "transparent")),
            scales = list(arrows = FALSE, col = "black"),
            # Rotate the screen (default: z = 40, x = -60). The `manipulate::manipulate()`
            # function is a useful tool for inspecting the plot from different directions.
            screen = list(z = z, x = -60),
            # Coloring.
            drape = TRUE,
            col.regions = colorRampPalette(c("#5e4fa2", "#3288bd", "#66c2a5", "#abdda4", "#e6f598"))(100),
            colorkey = FALSE,
            # The ratios of the y/x axis and the z/x axis.
            aspect = c(1, 1)) %>%
    print()  # lattice plots are not auto-printed inside a function
  dev.off()
})

# A few more notes.
# The following parameters may be useful when there are multiple panels:
# `as.table = TRUE`
# `strip = strip.custom(strip.names = TRUE)`
# `layout = c(2, 3)`

system("convert -delay 100 *.png figure_01.gif")
# The `convert` program is part of the ImageMagick suite of tools, which must be installed. The command as written assumes a Unix-like system (e.g. Linux), where the shell expands `*.png`.
# The `-delay` option controls how many ticks to wait before showing the next frame; by default there are 100 ticks per second, so `-delay 100` shows each frame for one second.