Scatter plot with multiple groups

In many situations, you need a plot that shows several categories at once: for example, colouring the points of a scatter plot by gender and drawing a separate regression line for each group.

Let's scatter some points using the mtcars data, which is available by default in R. Continuing the discussion from this post, where we plotted miles per gallon (mpg) against displacement (disp), here we will fit a separate regression line for each number of cylinders (cyl) and colour the points accordingly. Again, I will use three plotting systems: base graphics, lattice and ggplot2.

First, let's fit our linear model with the number of cylinders as a categorical variable.

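# Treat cyl as a factor so each cylinder count becomes its own level
mtcars <- within(mtcars, cyl <- as.factor(cyl))
# disp * cyl expands to disp + cyl + disp:cyl, i.e. a separate
# intercept and slope for each level of cyl
mdl <- lm(mpg ~ disp * cyl, data = mtcars)
sumry <- summary(mdl)
sumry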

Call:
lm(formula = mpg ~ disp * cyl, data = mtcars)

Residuals:
    Min      1Q  Median      3Q     Max 
-3.4766 -1.8101 -0.2297  1.3523  5.0208 

Coefficients:
             Estimate Std. Error t value Pr(>|t|)    
(Intercept)  40.87196    3.02012  13.533 2.79e-13 ***
disp         -0.13514    0.02791  -4.842 5.10e-05 ***
cyl6        -21.78997    5.30660  -4.106 0.000354 ***
cyl8        -18.83916    4.61166  -4.085 0.000374 ***
disp:cyl6     0.13875    0.03635   3.817 0.000753 ***
disp:cyl8     0.11551    0.02955   3.909 0.000592 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 2.372 on 26 degrees of freedom
Multiple R-squared:  0.8701,    Adjusted R-squared:  0.8452 
F-statistic: 34.84 on 5 and 26 DF,  p-value: 9.968e-11
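As a quick check of what the interaction model does, the sketch below fits a separate simple regression to each level of cyl and compares the result with intercepts and slopes assembled from the combined model. It works on its own copy of the data, named mt here (an illustrative name, as are fit, b, by.group and from.interaction). With the full disp * cyl interaction, the per-group intercepts and slopes should match the separate fits; only the standard errors differ, because the interaction model pools the residual variance across all three groups.

```r
# Copy of mtcars with cyl as a factor (levels "4", "6", "8")
mt <- within(mtcars, cyl <- as.factor(cyl))
fit <- lm(mpg ~ disp * cyl, data = mt)
b <- coef(fit)

# One simple regression per cylinder group: a 2 x 3 matrix with rows
# "(Intercept)" and "disp" and one column per level of cyl
by.group <- sapply(split(mt, mt$cyl),
                   function(d) coef(lm(mpg ~ disp, data = d)))

# The same intercepts and slopes assembled from the interaction model:
# baseline (4 cylinders) plus the cyl6/cyl8 and disp:cyl6/disp:cyl8 offsets
from.interaction <- rbind(
  intercept = unname(b["(Intercept)"] + c(0, b["cyl6"], b["cyl8"])),
  slope     = unname(b["disp"] + c(0, b["disp:cyl6"], b["disp:cyl8"]))
)

print(by.group)
# Compare the two, ignoring row/column names (equal up to floating-point tolerance)
all.equal(unname(from.interaction), unname(by.group))
```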

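Next, let's build an (intercept, slope) pair for each level of cyl and write a separate equation for each. Four cylinders is the baseline level, so its pair is simply (Intercept) and disp; for six and eight cylinders, the cyl6 and cyl8 coefficients shift the intercept, and disp:cyl6 and disp:cyl8 shift the slope.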

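# Coefficients in the order shown above: (Intercept), disp, cyl6, cyl8,
# disp:cyl6, disp:cyl8
cf <- round(coef(mdl), 3)
eq1 <- c(cf[1], cf[2])                  # 4 cylinders (baseline level)
eq2 <- c(cf[1] + cf[3], cf[2] + cf[5])  # 6 cylinders
eq3 <- c(cf[1] + cf[4], cf[2] + cf[6])  # 8 cylinders
# Turn an (intercept, slope) pair into a string like "mpg = 40.872 - 0.135 disp"
eq.fn <- function(eq) {
  paste0("mpg = ", eq[1],
         ifelse(eq[2] < 0, " - ", " + "),
         abs(eq[2]), " disp")
}
eqn <- sapply(list(eq1, eq2, eq3), eq.fn)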
rsq.info <- c(paste0("R^2: ", round(sumry[["r.squared"]], 2)),
              paste0("adj. R^2: ", round(sumry[["adj.r.squared"]], 2)))
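## Base graphics: car::scatterplot() draws one regression line per group
with(mtcars, {
  car::scatterplot(disp, mpg, groups = cyl, smooth = FALSE,
                   col = 1:3,  # match the colours used in the legend below
                   xlab = "Displacement",
                   ylab = "Miles per gallon",
                   main = "Miles per gallon vs displacement")
})
# Monospaced font so the equations line up in the legend
op <- par(family = "mono")
legend("topright", text.col = 1:3, col = 1:3, box.lty = 0,
       lty = 1, lwd = 1, legend = eqn)
legend("bottomleft", legend = rsq.info)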
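Picking coefficients by position (cf[3], cf[5] and so on) depends on the order in which lm() lays out the terms. A sketch of one alternative, assuming the same data: the formula 0 + cyl + cyl:disp describes the same fitted lines but parametrises them directly as one intercept and one slope per level of cyl, so each pair can be looked up by name. The object names mt, fit.sep, b, lv and line.coefs are illustrative.

```r
mt <- within(mtcars, cyl <- as.factor(cyl))

# Same fitted lines as mpg ~ disp * cyl, but with no shared intercept:
# coefficients cyl4, cyl6, cyl8 (intercepts) and cyl4:disp, ... (slopes)
fit.sep <- lm(mpg ~ 0 + cyl + cyl:disp, data = mt)
b <- coef(fit.sep)

# One row per level of cyl, columns for intercept and slope
lv <- levels(mt$cyl)
line.coefs <- cbind(intercept = b[paste0("cyl", lv)],
                    slope     = b[paste0("cyl", lv, ":disp")])
rownames(line.coefs) <- lv
round(line.coefs, 3)
```

One caveat: because this formula has no intercept, summary(fit.sep) computes R² against zero rather than against the mean of mpg, so that value is not comparable with the R² of mdl; the R² shown on the plots should still come from the original model.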

# Restore the graphics parameters changed for the base-graphics plot
par(op)

## Lattice: draw the three fitted lines and the text in a custom panel
library(lattice)
lm.panel <- function(x, y, ...) {
  panel.xyplot(x, y, ...)
  panel.abline(eq1, col = 1, lwd = 2, lty = 2)
  panel.abline(eq2, col = 2, lwd = 2, lty = 2)
  panel.abline(eq3, col = 3, lwd = 2, lty = 2)
  panel.text(max(x), max(y), pos = 2,
             fontfamily = "mono",
             label = paste(eqn, collapse = "\n"))
  panel.text(min(x), min(y), pos = 4,
             fontfamily = "mono",
             label = paste(rsq.info, collapse = "\n"))
}
xyplot(mpg ~ disp, data = mtcars, groups = cyl,
       # Set the group colours via par.settings so that the points
       # and the auto.key legend use the same colours
       par.settings = list(superpose.symbol = list(col = 1:3)),
       auto.key = list(columns = 3, cex = 0.8,
                       title = "Cylinder"),
       xlab = "Displacement",
       ylab = "Miles per gallon",
       panel = lm.panel)
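If the equations and R² text are not needed, lattice can fit and draw the per-group lines on its own: with groups, type = c("p", "r") asks panel.xyplot to draw the points plus a least-squares line (via panel.lmline) for each group. A minimal sketch on a local copy of the data (mt and p are illustrative names); the superpose.line setting keeps the line colours in step with the points.

```r
library(lattice)

mt <- within(mtcars, cyl <- as.factor(cyl))

# "p" = points, "r" = regression line; with groups, both are drawn per group
p <- xyplot(mpg ~ disp, data = mt, groups = cyl,
            type = c("p", "r"),
            par.settings = list(superpose.symbol = list(col = 1:3),
                                superpose.line   = list(col = 1:3)),
            auto.key = list(columns = 3, title = "Cylinder"),
            xlab = "Displacement", ylab = "Miles per gallon")
print(p)
```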

library(ggplot2)
## Equation table: one row per level of cyl, stacked near the top-right corner
eqn.df <- with(mtcars, {
  data.frame(
    cyl = sort(unique(cyl)), mpg = max(mpg) - 1:3,
    disp = max(disp), eqn = eqn
  )
})
ggplot(mtcars, aes(disp, mpg, color = cyl)) +
  geom_point() +
  theme_bw(base_size = 14) +
  # One least-squares line per colour (cyl) group
  geom_smooth(method = "lm", formula = y ~ x, se = FALSE) +
  geom_text(data = eqn.df, aes(label = eqn),
            hjust = 1, family = "mono") +
  annotate(geom = "text",
           x = min(mtcars$disp), y = min(mtcars$mpg),
           family = "mono", hjust = 0,
           label = paste0(rsq.info, collapse = "\n"))
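geom_smooth(method = "lm") refits a simple regression inside each colour group. Because our model has the full disp * cyl interaction, those lines coincide with the ones implied by mdl, but for a different model (say, mpg ~ disp + cyl with a common slope) they would not. One way to draw the lines straight from a fitted model is to predict over each group's observed disp range and add the result with geom_line(). A sketch, where mt, fit, line.df and p are illustrative names:

```r
library(ggplot2)

mt <- within(mtcars, cyl <- as.factor(cyl))
fit <- lm(mpg ~ disp * cyl, data = mt)

# 50 evenly spaced disp values inside each group's observed range
line.df <- do.call(rbind, lapply(split(mt, mt$cyl), function(d) {
  data.frame(cyl = d$cyl[1],
             disp = seq(min(d$disp), max(d$disp), length.out = 50))
}))
# Fitted mpg from the model for every (cyl, disp) row
line.df$mpg <- predict(fit, newdata = line.df)

p <- ggplot(mt, aes(disp, mpg, colour = cyl)) +
  geom_point() +
  geom_line(data = line.df) +   # one line per colour (cyl) group
  theme_bw(base_size = 14)
print(p)
```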