Scatter plot with multiple groups

In many situations, it becomes essential to have a plot with multiple categories. For example, you may want to colour the points of a scatter plot according to gender and fit a separate regression line for each gender.

Let's plot some points using the mtcars data set, which is available by default in R. This continues the discussion from this post, where we plotted miles per gallon (mpg) against displacement (disp). Here we will colour the points by the number of cylinders (cyl) in each car and fit a separate regression line for each cylinder group. In this article, too, I will use three plotting systems – base graphics, lattice and ggplot2.

First, let's fit our linear model with the number of cylinders as a categorical variable.

# Work on a copy of mtcars in which the cylinder count is a factor, so that
# the model below estimates a separate intercept and slope per cylinder level.
mtcars$cyl <- as.factor(mtcars$cyl)

# Interaction model: mpg regressed on displacement, cylinder level and their
# product (disp * cyl expands to disp + cyl + disp:cyl).
mdl <- lm(formula = mpg ~ disp * cyl, data = mtcars)

# Keep the summary object; its R-squared values are reused for plot labels.
sumry <- summary(mdl)
sumry

Call:
lm(formula = mpg ~ disp * cyl, data = mtcars)

Residuals:
         Min           1Q       Median           3Q          Max 
-3.476639349 -1.810070736 -0.229717748  1.352296458  5.020839051 

Coefficients:
                  Estimate     Std. Error  t value   Pr(>|t|)
(Intercept)  40.8719553217   3.0201232876 13.53321 2.7906e-13
disp         -0.1351418146   0.0279090137 -4.84223 5.0965e-05
cyl6        -21.7899679029   5.3066001199 -4.10620 0.00035424
cyl8        -18.8391564080   4.6116619648 -4.08511 0.00037433
disp:cyl6     0.1387469331   0.0363533476  3.81662 0.00075284
disp:cyl8     0.1155077196   0.0295484483  3.90910 0.00059232

Residual standard error: 2.37158072 on 26 degrees of freedom
Multiple R-squared:  0.870134862,   Adjusted R-squared:  0.845160797 
F-statistic: 34.8415391 on 5 and 26 DF,  p-value: 9.96841323e-11

Let's create an (intercept, slope) pair for each level of cyl and build a separate equation for each of them.

# Build one (intercept, slope) pair per cylinder level from the interaction
# model. The 4-cylinder level is the reference (only cyl6/cyl8 terms appear in
# the coefficient table), so the 6- and 8-cylinder lines add their offsets to
# the reference intercept and slope.
#
# The unrounded coefficients are summed first and the result is rounded once.
# Rounding each term before adding can shift the third decimal: the
# 8-cylinder slope is -0.020, whereas -0.135 + 0.116 would give -0.019.
cf  <- coef(mdl)
eq1 <- round(c(cf[["(Intercept)"]], cf[["disp"]]), 3)
eq2 <- round(c(cf[["(Intercept)"]] + cf[["cyl6"]],
               cf[["disp"]] + cf[["disp:cyl6"]]), 3)
eq3 <- round(c(cf[["(Intercept)"]] + cf[["cyl8"]],
               cf[["disp"]] + cf[["disp:cyl8"]]), 3)
# Format an (intercept, slope) pair as a regression-equation label.
#
# eq: numeric vector of length 2, c(intercept, slope), without NAs.
# Returns a single string such as "mpg = 40.872 - 0.135 disp": the sign of
# the slope becomes the joining operator and its magnitude follows it.
eq.fn <- function(eq) {
  # Fail early rather than emit a label such as "mpg = 1NANA disp".
  stopifnot(is.numeric(eq), length(eq) == 2L, !anyNA(eq))
  # eq[2] is a single value, so a plain if/else replaces vectorised ifelse().
  sign.op <- if (eq[2] < 0) " - " else " + "
  paste0("mpg = ", eq[1], sign.op, abs(eq[2]), " disp")
}
# One equation label per cylinder level, in the same order as eq1, eq2, eq3.
# vapply() pins the result to a character vector with one string per pair,
# instead of letting sapply() pick the return type from the input.
eqn <- vapply(list(eq1, eq2, eq3), eq.fn, character(1))

# Goodness-of-fit labels for the fitted model, rounded to two decimals.
rsq.info <- c(paste0("R^2: ", round(sumry[["r.squared"]], 2)),
              paste0("adj. R^2: ", round(sumry[["adj.r.squared"]], 2)))

Plots

Base Graphics

# Grouped scatter plot of mpg against displacement, one group per cylinder
# level, drawn with car::scatterplot and no smoother (smooth = FALSE).
with(mtcars, {
  # NOTE(review): col = 1:3 pins the group colours to the ones used by the
  # hand-built equation legend below (col/text.col = 1:3). Recent car
  # releases appear to default to a different group palette - confirm
  # against the installed car version.
  car::scatterplot(disp, mpg, groups = cyl, smooth = FALSE,
                   col = 1:3,
                   xlab = "Displacement",
                   ylab = "Mile per Gallon",
                   main = "Mile per gallon vs displacement")
})

# Switch to a fixed-width font for the annotations and remember the previous
# setting. "mono" is one of the standard device-independent family names
# ("serif", "sans", "mono"); "monospace" is not and may be unknown to a device.
op <- par(family = "mono")

# Equation legend: one coloured line sample and label per cylinder level.
legend("topright", text.col = 1:3, col = 1:3, box.lty = 0,
       lty = 1, lwd = 1, legend = eqn)
# Model fit statistics in the opposite corner.
legend("bottomleft", legend = rsq.info)

# Restore the graphics parameters changed above.
par(op)

Lattice Plot

library(lattice)
# Panel function for the grouped lattice scatter plot.
#
# x, y: panel coordinates supplied by lattice.
# ...:  forwarded unchanged to panel.xyplot() (this is how the grouping
#       information from the xyplot() call reaches the point layer).
#
# Reads eq1, eq2, eq3 (intercept/slope pairs), eqn (equation labels) and
# rsq.info (fit statistics) from the enclosing environment.
lm.panel <- function(x, y, ...) {
  panel.xyplot(x, y, ...)
  # One dashed regression line per cylinder level, coloured 1, 2, 3.
  panel.abline(eq1, col = 1, lwd = 2, lty = 2)
  panel.abline(eq2, col = 2, lwd = 2, lty = 2)
  panel.abline(eq3, col = 3, lwd = 2, lty = 2)
  # Equations at the top-right data corner, text placed to the left (pos = 2).
  # "mono" is the standard fixed-width family name; "monospace" is not a
  # standard family and may not be recognised by every graphics device.
  panel.text(max(x), max(y), pos = 2,
             fontfamily = "mono",
             label = paste(eqn, collapse = "\n"))
  # Fit statistics at the bottom-left data corner, text to the right (pos = 4).
  panel.text(min(x), min(y), pos = 4,
             fontfamily = "mono",
             label = paste(rsq.info, collapse = "\n"))
}
# Grouped lattice scatter plot drawn with the custom panel function.
# The group colours are set through par.settings rather than a bare `col`
# argument: auto.key builds its key from the theme's superpose.symbol
# settings, so this keeps the key and the plotted points in the same colours.
xyplot(mpg ~ disp, data = mtcars, groups = cyl,
       auto.key = list(columns = 3, cex = 0.8,
                       title = "Cylinder"),
       par.settings = list(superpose.symbol = list(col = 1:3)),
       xlab = "Displacement",
       ylab = "Mile per gallon",
       panel = lm.panel)

ggplot2

library(ggplot2)
## Equation table for displaying them on the plot: one row per cylinder
## level, holding the label text and the (disp, mpg) position to draw it at.
eqn.df <- with(mtcars, {
  data.frame(
    cyl  = sort(unique(cyl)),
    # Stack the labels one mpg unit apart, starting just below the maximum.
    mpg  = max(mpg) - seq_along(eqn),
    disp = max(disp), eqn = eqn
  )
})
## Grouped scatter plot with one least-squares line per cylinder level.
ggplot(mtcars, aes(disp, mpg, color = cyl)) +
  geom_point() +
  theme_bw(base_size = 14) +
  # The formula is spelled out so the smoother does not have to guess it.
  geom_smooth(method = "lm", formula = y ~ x, se = FALSE) +
  # Right-aligned equation labels, coloured by cylinder level; kept out of
  # the legend so the colour keys are not overdrawn with text glyphs.
  geom_text(data = eqn.df, aes(label = eqn),
            hjust = 1, family = "mono", show.legend = FALSE) +
  # Model fit statistics, left-aligned at the bottom-left data corner.
  annotate(geom = "text",
           x = min(mtcars$disp), y = min(mtcars$mpg),
           family = "mono", hjust = 0,
           label = paste0(rsq.info, collapse = "\n"))