Loading [MathJax]/jax/output/HTML-CSS/jax.js
library("rstan")
library("ggplot2")
library("bayesplot")
library("dplyr")
library("tidyr")
options(mc.cores = parallel::detectCores())

Multinomial logit model review

Say you have survey responses in which each respondent picks exactly one of K unordered choices. We can model the probability that respondent i picks choice j, $\Pr(Y_i = j)$, like so:

$$\pi_{ij} = \frac{\exp(\eta_{ij})}{\sum_{k=1}^{K} \exp(\eta_{ik})}$$

where $\eta_{ij}$ is the log-odds-scale linear predictor for respondent $i$ and choice $j$. Note that you can add the same constant to all of a respondent's $\eta_{ij}$ and get the same probabilities out. This is what we call a nonidentifiability, and it is bad news. We can fix one category's linear predictor at zero (so its exponential is unity), and this identifies the model. Our likelihood is now

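$$\pi_{ij} = \frac{\exp(\eta_{ij})}{1 + \sum_{k=1}^{K-1} \exp(\eta_{ik})}, \qquad j = 1, \dots, K-1,$$

with $\pi_{iK} = 1 \big/ \bigl(1 + \sum_{k=1}^{K-1} \exp(\eta_{ik})\bigr)$ for the reference category $K$.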

There is a strong assumption in these models, namely the independence of irrelevant alternatives (IIA). It means that the relative odds of choosing one alternative over another (say, Hillary Clinton vs. Donald Trump) do not depend on the presence or characteristics of any other alternatives.

These models involve estimating K−1 simultaneous equations and can thus be computationally expensive. We’ll simulate some data to get a feel for fitting these models.

Generate fake data

We haven’t done fake data generation and model checking yet, but this is an important part of modeling in Stan, especially when your models are complex.

We’ll generate data from the simplest model, for a dataset with three choices (K = 3).

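$$
\begin{aligned}
y_n &\sim \mathrm{Categorical}\bigl(\mathrm{softmax}(\alpha + \nu_n)\bigr) \\
\nu_n &= X_n \beta, \qquad X_n \in \mathbb{R}^{D}, \qquad \beta \in \mathbb{R}^{D \times (K-1)} \\
\beta_{d,k} &\sim \mathrm{Normal}(0, 1) \\
\alpha_k &\sim \mathrm{Normal}(0, 1)
\end{aligned}
$$

The reference category's coefficients and intercept ($\beta_{\cdot,K}$ and $\alpha_K$) are fixed at zero.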

# Simulate N responses from a K-category multinomial logit with D predictors.
# Category K is the reference: its column of beta and its intercept are fixed
# at zero, which removes the additive non-identifiability of the softmax.
# Seed the simulation so the fake data are reproducible, as the later
# sections of this document do.
set.seed(123)
N <- 3000
K <- 3
D <- 3
X <- matrix(rnorm(N * D), N, D)
beta <- cbind(matrix(rnorm((K - 1) * D), D, K - 1), 0)
alpha <- c(rnorm(K - 1), 0)
# Linear predictor: add the intercept vector alpha to every row of X %*% beta.
mu <- sweep(x = X %*% beta, MARGIN = 2, STATS = alpha, FUN = '+')
# NOTE(review): softmax() is not base R and is not defined in the visible
# code -- presumably it comes from a hidden setup chunk; confirm.
mu_soft <- t(apply(mu, 1, softmax))
# Draw one category per respondent: rmultinom() with size = 1 returns a
# one-hot count vector, and which() recovers the chosen category index.
y <- vapply(seq_len(N), function(n) {
  which(rmultinom(1, size = 1, prob = mu_soft[n, ]) == 1L)
}, integer(1))
mod1 <- stan_model(file = 'mnl_constrained.stan')
## In file included from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/BH/include/boost/config.hpp:39:0,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/BH/include/boost/math/tools/config.hpp:13,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/stan/math/rev/core/var.hpp:7,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/stan/math/rev/core/gevv_vvv_vari.hpp:5,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/stan/math/rev/core.hpp:12,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/stan/math/rev/mat.hpp:4,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/stan/math.hpp:4,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/src/stan/model/model_header.hpp:4,
##                  from file72e36f737c64.cpp:8:
## /home/rob/R/x86_64-pc-linux-gnu-library/3.4/BH/include/boost/config/compiler/gcc.hpp:186:0: warning: "BOOST_NO_CXX11_RVALUE_REFERENCES" redefined
##  #  define BOOST_NO_CXX11_RVALUE_REFERENCES
##  ^
## <command-line>:0:0: note: this is the location of the previous definition

Fit our model

# Data for mnl_constrained.stan: outcomes, sample size, number of predictors
# and the N x D design matrix.
stan_dat <- list(
  y = y,
  N = N,
  D = D,
  X = X
)
fit1 <- sampling(
  mod1,
  data = stan_dat,
  iter = 1000,
  seed = 349596
)

Inspect the model: convergence, effective sample size

# Posterior summaries for every parameter: check Rhat is 1 and n_eff is large.
# The NaN Rhat values for beta[3,*] and alpha[3] are expected -- those
# reference-category entries are constant zeros in every draw.
print(fit1)
## Inference for Stan model: mnl_constrained.
## 4 chains, each with iter=1000; warmup=500; thin=1; 
## post-warmup draws per chain=500, total post-warmup draws=2000.
## 
##                   mean se_mean   sd     2.5%      25%      50%      75%
## beta_raw[1,1]     2.04    0.00 0.10     1.85     1.98     2.04     2.10
## beta_raw[1,2]    -0.69    0.00 0.08    -0.85    -0.75    -0.69    -0.64
## beta_raw[1,3]    -0.59    0.00 0.06    -0.71    -0.63    -0.59    -0.55
## beta_raw[2,1]    -1.05    0.00 0.11    -1.27    -1.12    -1.05    -0.97
## beta_raw[2,2]     2.84    0.00 0.14     2.59     2.74     2.84     2.93
## beta_raw[2,3]    -0.76    0.00 0.07    -0.90    -0.81    -0.76    -0.71
## alpha_raw[1]      1.34    0.00 0.07     1.21     1.30     1.34     1.39
## alpha_raw[2]     -1.23    0.00 0.13    -1.50    -1.32    -1.23    -1.14
## beta[1,1]         2.04    0.00 0.10     1.85     1.98     2.04     2.10
## beta[1,2]        -0.69    0.00 0.08    -0.85    -0.75    -0.69    -0.64
## beta[1,3]        -0.59    0.00 0.06    -0.71    -0.63    -0.59    -0.55
## beta[2,1]        -1.05    0.00 0.11    -1.27    -1.12    -1.05    -0.97
## beta[2,2]         2.84    0.00 0.14     2.59     2.74     2.84     2.93
## beta[2,3]        -0.76    0.00 0.07    -0.90    -0.81    -0.76    -0.71
## beta[3,1]         0.00    0.00 0.00     0.00     0.00     0.00     0.00
## beta[3,2]         0.00    0.00 0.00     0.00     0.00     0.00     0.00
## beta[3,3]         0.00    0.00 0.00     0.00     0.00     0.00     0.00
## alpha[1]          1.34    0.00 0.07     1.21     1.30     1.34     1.39
## alpha[2]         -1.23    0.00 0.13    -1.50    -1.32    -1.23    -1.14
## alpha[3]          0.00    0.00 0.00     0.00     0.00     0.00     0.00
## lp__          -1547.42    0.06 1.93 -1551.88 -1548.54 -1547.12 -1546.03
##                  97.5% n_eff Rhat
## beta_raw[1,1]     2.23  1806    1
## beta_raw[1,2]    -0.54  1791    1
## beta_raw[1,3]    -0.47  2000    1
## beta_raw[2,1]    -0.83  1237    1
## beta_raw[2,2]     3.12  1427    1
## beta_raw[2,3]    -0.62  1714    1
## alpha_raw[1]      1.48  1571    1
## alpha_raw[2]     -0.96  1158    1
## beta[1,1]         2.23  1806    1
## beta[1,2]        -0.54  1791    1
## beta[1,3]        -0.47  2000    1
## beta[2,1]        -0.83  1237    1
## beta[2,2]         3.12  1427    1
## beta[2,3]        -0.62  1714    1
## beta[3,1]         0.00  2000  NaN
## beta[3,2]         0.00  2000  NaN
## beta[3,3]         0.00  2000  NaN
## alpha[1]          1.48  1571    1
## alpha[2]         -0.96  1158    1
## alpha[3]          0.00  2000  NaN
## lp__          -1544.57  1029    1
## 
## Samples were drawn using NUTS(diag_e) at Mon Aug 14 16:16:46 2017.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

Check inferences vs. the true parameters

# Posterior draws as (draws x parameters) matrices for the recovery plots.
post_draws_beta <- as.matrix(fit1, pars = 'beta')
post_draws_alpha <- as.matrix(fit1, pars = 'alpha')

# Check the column order of the draws: beta is flattened column-major
# (beta[1,1], beta[2,1], beta[3,1], ...), so the true values passed to
# mcmc_recover_intervals() must be supplied in that same order.
print(colnames(post_draws_beta))
## [1] "beta[1,1]" "beta[2,1]" "beta[3,1]" "beta[1,2]" "beta[2,2]" "beta[3,2]"
## [7] "beta[1,3]" "beta[2,3]" "beta[3,3]"
print(colnames(post_draws_alpha))
## [1] "alpha[1]" "alpha[2]" "alpha[3]"
# The printed summary shows row 3 (category K) of Stan's beta fixed at zero,
# i.e. Stan's beta is the transpose of the simulated D x K beta. Flattening
# t(beta) column-major therefore lines up with the draws' column order.
true_beta <- as.vector(t(beta))
mcmc_recover_intervals(post_draws_beta, true_beta)
mcmc_recover_intervals(post_draws_alpha, alpha)

Hierarchical multinomial logit

Let’s generate some data from a hierarchical MNL model.

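$$
\begin{aligned}
y_n &\sim \mathrm{Categorical}\bigl(\mathrm{softmax}(\alpha + \alpha^{\mathrm{age}}_{[\cdot,\, \mathrm{idx\_age}[n]]} + \alpha^{\mathrm{edu}}_{[\cdot,\, \mathrm{idx\_edu}[n]]} + \alpha^{\mathrm{eth}}_{[\cdot,\, \mathrm{idx\_eth}[n]]} + \nu_n)\bigr) \\
\nu_n &= X_n \beta \\
\alpha^{\mathrm{age}}_{k,\cdot} &= \sigma^{\mathrm{age}}_k \, \eta^{\mathrm{age}}_{k,\cdot}, \quad \alpha^{\mathrm{edu}}_{k,\cdot} = \sigma^{\mathrm{edu}}_k \, \eta^{\mathrm{edu}}_{k,\cdot}, \quad \alpha^{\mathrm{eth}}_{k,\cdot} = \sigma^{\mathrm{eth}}_k \, \eta^{\mathrm{eth}}_{k,\cdot} \\
\alpha^{\mathrm{age}} &\in \mathbb{R}^{(K-1) \times J_{\mathrm{age}}}, \quad \alpha^{\mathrm{edu}} \in \mathbb{R}^{(K-1) \times J_{\mathrm{edu}}}, \ \dots \\
\sigma^{\mathrm{age}}, \sigma^{\mathrm{edu}}, \sigma^{\mathrm{eth}} &\in \mathbb{R}^{K-1} \\
\sigma^{\mathrm{age}}_k &= \tau^{\mathrm{age}}_k \, \sigma^{\mathrm{age}}_{\mathrm{inter\_eqn}}, \quad \sigma^{\mathrm{edu}}_k = \tau^{\mathrm{edu}}_k \, \sigma^{\mathrm{edu}}_{\mathrm{inter\_eqn}}, \ \dots \\
\beta_{d,k} &\sim \mathrm{Normal}(0, 1) \\
\eta &\sim \mathrm{Normal}(0, 1) \\
\sigma^{\mathrm{age}}_{\mathrm{inter\_eqn}}, \sigma^{\mathrm{edu}}_{\mathrm{inter\_eqn}}, \sigma^{\mathrm{eth}}_{\mathrm{inter\_eqn}} &\sim \mathrm{Normal}^{+}(0, 1) \\
\tau^{\mathrm{age}}, \tau^{\mathrm{edu}}, \tau^{\mathrm{eth}} &\sim \mathrm{Normal}^{+}(0, 1)
\end{aligned}
$$

### Generate fake data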

# Simulate N responses from a hierarchical multinomial logit: fixed effects
# X %*% beta plus varying intercepts for age, ethnicity and education groups.
# Category K is the reference, so its beta column, intercept and group-effect
# rows are all zero.
set.seed(123)
N <- 1000
K <- 3
D <- 3
J_age <- 5
J_eth <- 4
J_edu <- 5
G <- 3
X <- matrix(rnorm(N * D), N, D)
beta <- cbind(matrix(rnorm((K - 1) * D), D, K - 1), 0)
alpha <- c(rnorm(K - 1), 0)
# Standard-normal raw group effects (non-centered parameterization).
eta_age <- matrix(rnorm((K - 1) * J_age), K - 1, J_age)
eta_eth <- matrix(rnorm((K - 1) * J_eth), K - 1, J_eth)
eta_edu <- matrix(rnorm((K - 1) * J_edu), K - 1, J_edu)
alpha_age <- matrix(0, K, J_age)
alpha_eth <- matrix(0, K, J_eth)
alpha_edu <- matrix(0, K, J_edu)
# Per-equation scales (tau in the model) and per-factor scales
# (sigma_inter_eqn); abs() of a standard normal gives a half-normal draw.
sigma_age <- abs(rnorm(K - 1))
sigma_eth <- abs(rnorm(K - 1))
sigma_edu <- abs(rnorm(K - 1))
sigma_inter_eqn <- abs(rnorm(G))
# Scale row k of each eta matrix by sigma_inter_eqn[g] * sigma_*[k]. The
# length-(K - 1) scale vector recycles down the columns, so element k lines
# up with row k. Row K keeps its initial zeros (reference category).
alpha_age[-K, ] <- sigma_inter_eqn[1] * sigma_age * eta_age
alpha_eth[-K, ] <- sigma_inter_eqn[2] * sigma_eth * eta_eth
alpha_edu[-K, ] <- sigma_inter_eqn[3] * sigma_edu * eta_edu

idx_age <- sample(J_age, N, replace = TRUE)
idx_eth <- sample(J_eth, N, replace = TRUE)
idx_edu <- sample(J_edu, N, replace = TRUE)

mu <- sweep(x = X %*% beta, MARGIN = 2, STATS = alpha, FUN = '+')
# Group effects are stored K x J, so work on t(mu) (K x N) and transpose back.
mu <- t(t(mu) + alpha_age[, idx_age] + alpha_eth[, idx_eth] +
          alpha_edu[, idx_edu])
mu_soft <- t(apply(mu, 1, softmax))
# One category per respondent from a one-hot multinomial draw.
y <- vapply(seq_len(N), function(n) {
  which(rmultinom(1, size = 1, prob = mu_soft[n, ]) == 1L)
}, integer(1))

Compile and fit the model

# Compile the hierarchical model. The readLines() warning below only reports
# that the .stan file lacks a trailing newline; the model still compiles and
# samples below.
hier_mnl <- stan_model(file = 'mnl_constrained_hier.stan')
## Warning in readLines(file, warn = TRUE): incomplete final line
## found on '/home/rob/Dropbox/class_20170809/multinomial-logit/
## mnl_constrained_hier.stan'
## In file included from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/BH/include/boost/config.hpp:39:0,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/BH/include/boost/math/tools/config.hpp:13,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/stan/math/rev/core/var.hpp:7,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/stan/math/rev/core/gevv_vvv_vari.hpp:5,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/stan/math/rev/core.hpp:12,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/stan/math/rev/mat.hpp:4,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/stan/math.hpp:4,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/src/stan/model/model_header.hpp:4,
##                  from file72e32318ef9.cpp:8:
## /home/rob/R/x86_64-pc-linux-gnu-library/3.4/BH/include/boost/config/compiler/gcc.hpp:186:0: warning: "BOOST_NO_CXX11_RVALUE_REFERENCES" redefined
##  #  define BOOST_NO_CXX11_RVALUE_REFERENCES
##  ^
## <command-line>:0:0: note: this is the location of the previous definition
# Data for mnl_constrained_hier.stan. The simulated outcomes y are passed
# explicitly: when a data variable is missing from the list, rstan falls back
# to searching the calling environment, so the fit would silently depend on
# whatever global `y` happens to exist.
stan_dat <- list(N = N, K = K, D = D, G = G,
                 J_age = J_age, J_eth = J_eth, J_edu = J_edu,
                 idx_age = idx_age, idx_eth = idx_eth, idx_edu = idx_edu,
                 X = X, y = y)

# Seeded like the other fits in this document so results are reproducible.
fit_hier_mnl <- sampling(hier_mnl, data = stan_dat, iter = 1000,
                         seed = 349596)

Examine the parameters

# Posterior summaries of the hierarchical scale parameters only: the shared
# per-factor scales (sigma_inter_eqn) and the per-equation scales.
print(fit_hier_mnl, pars = c('sigma_inter_eqn','sigma_age','sigma_edu','sigma_eth'))
## Inference for Stan model: mnl_constrained_hier.
## 4 chains, each with iter=1000; warmup=500; thin=1; 
## post-warmup draws per chain=500, total post-warmup draws=2000.
## 
##                    mean se_mean   sd 2.5%  25%  50%  75% 97.5% n_eff Rhat
## sigma_inter_eqn[1] 1.01    0.01 0.45 0.37 0.69 0.92 1.26  2.13  1632    1
## sigma_inter_eqn[2] 1.29    0.01 0.49 0.56 0.93 1.23 1.57  2.41  2000    1
## sigma_inter_eqn[3] 0.54    0.01 0.38 0.09 0.27 0.44 0.72  1.56  2000    1
## sigma_age[1]       1.33    0.01 0.54 0.53 0.93 1.25 1.65  2.60  2000    1
## sigma_age[2]       0.26    0.01 0.24 0.01 0.08 0.19 0.36  0.96  1625    1
## sigma_edu[1]       0.70    0.01 0.51 0.03 0.30 0.61 0.98  1.97  2000    1
## sigma_edu[2]       0.63    0.01 0.50 0.03 0.24 0.52 0.89  1.90  2000    1
## sigma_eth[1]       1.56    0.01 0.56 0.69 1.13 1.50 1.91  2.79  2000    1
## sigma_eth[2]       0.25    0.01 0.25 0.01 0.08 0.18 0.35  0.95  1114    1
## 
## Samples were drawn using NUTS(diag_e) at Mon Aug 14 16:18:30 2017.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

Inspect inferences vs. true values

# Compare posterior intervals against the values used to simulate the data.
post_draws_beta <- as.matrix(fit_hier_mnl, pars = 'beta')
true <- as.vector(t(beta))
mcmc_recover_intervals(post_draws_beta, true)

# NOTE(review): assumes Stan's alpha_age has the same K x J_age layout as the
# simulated matrix, so column-major flattening matches the draws -- confirm.
post_draws_age <- as.matrix(fit_hier_mnl, pars = 'alpha_age')
true <- as.vector(alpha_age)
mcmc_recover_intervals(post_draws_age, true)

post_draws_sigma_inter <- as.matrix(fit_hier_mnl, pars = 'sigma_inter_eqn')
mcmc_recover_intervals(post_draws_sigma_inter, sigma_inter_eqn)

Binning by cell

If we don’t have individual-level predictors, we can speed up our model by binning responses by cell. Below, we generate individual-level fake data and then bin it up to the cell level.

# Simulate N individual responses whose log-odds depend only on the age,
# ethnicity and education cells (no individual-level predictors), so the data
# can be binned by cell afterwards. Category K is the reference.
set.seed(123)
N <- 30000
K <- 6
D <- 3  # not used in this section: there is no design matrix X here
J_age <- 5
J_eth <- 4
J_edu <- 5
G <- 3
alpha <- c(rnorm(K - 1), 0)
# Standard-normal raw group effects (non-centered parameterization).
eta_age <- matrix(rnorm((K - 1) * J_age), K - 1, J_age)
eta_eth <- matrix(rnorm((K - 1) * J_eth), K - 1, J_eth)
eta_edu <- matrix(rnorm((K - 1) * J_edu), K - 1, J_edu)
age <- matrix(0, K, J_age)
eth <- matrix(0, K, J_eth)
edu <- matrix(0, K, J_edu)
# Half-normal scales: per equation (sigma_*) and per factor (sigma_inter_eqn).
sigma_age <- abs(rnorm(K - 1))
sigma_eth <- abs(rnorm(K - 1))
sigma_edu <- abs(rnorm(K - 1))
sigma_inter_eqn <- abs(rnorm(G))
# Scale row k of each eta matrix by sigma_inter_eqn[g] * sigma_*[k]; the
# length-(K - 1) vector recycles down the columns, matching rows. Row K keeps
# its initial zeros for the reference category.
age[-K, ] <- sigma_inter_eqn[1] * sigma_age * eta_age
eth[-K, ] <- sigma_inter_eqn[2] * sigma_eth * eta_eth
edu[-K, ] <- sigma_inter_eqn[3] * sigma_edu * eta_edu

idx_age <- sample(J_age, N, replace = TRUE)
idx_eth <- sample(J_eth, N, replace = TRUE)
idx_edu <- sample(J_edu, N, replace = TRUE)

# Effects are stored K x J, so build K x N and transpose to N x K.
mu <- t(age[, idx_age] + edu[, idx_edu] + eth[, idx_eth])
mu <- sweep(x = mu, MARGIN = 2, STATS = alpha, FUN = '+')
mu_soft <- t(apply(mu, 1, softmax))
# One category per respondent from a one-hot multinomial draw.
y <- vapply(seq_len(N), function(n) {
  which(rmultinom(1, size = 1, prob = mu_soft[n, ]) == 1L)
}, integer(1))
# Bin individual responses by cell: one row per observed
# (idx_age, idx_edu, idx_eth) combination and one count column for EVERY
# outcome category 1..K. The column set is derived from K rather than
# hard-coded, so no category's counts are dropped when K changes.
mod_df <- data.frame(y = y,
                     idx_age = idx_age,
                     idx_edu = idx_edu,
                     idx_eth = idx_eth)
outcome_cols <- as.character(seq_len(K))
cleaned_dat <- mod_df %>%
  count(idx_age, idx_edu, idx_eth, y) %>%
  ungroup() %>%
  # Cells where a category never occurs get a zero count rather than NA.
  spread(key = y, value = n, fill = 0L)
# spread() only creates columns for categories that occur somewhere in the
# sample; add any absent category as all-zero counts so y has K columns.
for (col in setdiff(outcome_cols, names(cleaned_dat))) {
  cleaned_dat[[col]] <- 0L
}

N <- nrow(cleaned_dat)
idx_age <- cleaned_dat$idx_age
idx_eth <- cleaned_dat$idx_eth
idx_edu <- cleaned_dat$idx_edu
# Integer count matrix: N cells x K categories, columns in category order.
y <- data.matrix(cleaned_dat[, outcome_cols, drop = FALSE])

stan_dat <- list(N = N, K = K, G = G,
                 J_age = J_age, J_eth = J_eth, J_edu = J_edu,
                 idx_age = idx_age, idx_eth = idx_eth, idx_edu = idx_edu,
                 y = y)

Multinomial likelihood

When we bin by cell, we can use the multinomial likelihood to model the data rather than the categorical likelihood.

We’ll also need the softmax function, which Stan provides as a specialized built-in function.

2016 Election example

Let’s load the data and see what it looks like:

# Load the election polling data, stored as a serialized R object.
polls <- readRDS(file = 'election_data.RDS')

Let’s modify our hierarchical logit model to include a random intercept for state and to drop the education and ethnicity effects. We’ll have 3 equations and one reference category, which will be Jill Stein. We won’t add individual-level predictors (the $X\beta$ term) yet.

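$$
\begin{aligned}
y_n &\sim \mathrm{Multinomial}\bigl(\mathrm{softmax}(\alpha + \alpha^{\mathrm{age}}_{[\cdot,\, \mathrm{idx\_age}[n]]} + \alpha^{\mathrm{state}}_{[\cdot,\, \mathrm{idx\_state}[n]]})\bigr) \\
\alpha^{\mathrm{age}}_{k,\cdot} &= \sigma^{\mathrm{age}}_k \, \eta^{\mathrm{age}}_{k,\cdot}, \quad \alpha^{\mathrm{state}}_{k,\cdot} = \sigma^{\mathrm{state}}_k \, \eta^{\mathrm{state}}_{k,\cdot} \\
\alpha^{\mathrm{age}} &\in \mathbb{R}^{(K-1) \times J_{\mathrm{age}}}, \quad \alpha^{\mathrm{state}} \in \mathbb{R}^{(K-1) \times J_{\mathrm{state}}} \\
\sigma^{\mathrm{age}}, \sigma^{\mathrm{state}} &\in \mathbb{R}^{K-1} \\
\sigma^{\mathrm{age}}_k &= \tau^{\mathrm{age}}_k \, \sigma^{\mathrm{age}}_{\mathrm{inter\_eqn}}, \quad \sigma^{\mathrm{state}}_k = \tau^{\mathrm{state}}_k \, \sigma^{\mathrm{state}}_{\mathrm{inter\_eqn}} \\
\beta_{d,k} &\sim \mathrm{Normal}(0, 1) \\
\eta &\sim \mathrm{Normal}(0, 1) \\
\sigma^{\mathrm{age}}_{\mathrm{inter\_eqn}}, \sigma^{\mathrm{state}}_{\mathrm{inter\_eqn}} &\sim \mathrm{Normal}^{+}(0, 1) \\
\tau^{\mathrm{age}}, \tau^{\mathrm{state}} &\sim \mathrm{Normal}^{+}(0, 1)
\end{aligned}
$$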

# Map the age and state factors to integer indices for Stan.
# NOTE(review): fac2int() and n_levels() are not defined in the visible code;
# presumably they are helpers from a hidden setup chunk -- confirm.
polls <- polls %>%
  ungroup() %>%
  mutate(
    idx_age = fac2int(age),
    idx_state = fac2int(state)
  )
N <- nrow(polls)
idx_age <- polls$idx_age
idx_state <- polls$idx_state
J_age <- n_levels(idx_age)
J_state <- n_levels(idx_state)
# One count column per candidate. Stein comes last, making it the reference
# category (alpha[4] is fixed at zero in the fit).
y <- data.matrix(select(polls, Clinton, Trump, Johnson, Stein))
K <- ncol(y)
elec_mod <- stan_model('mnl_constrained_multinom_lik_election.stan')
## In file included from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/BH/include/boost/config.hpp:39:0,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/BH/include/boost/math/tools/config.hpp:13,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/stan/math/rev/core/var.hpp:7,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/stan/math/rev/core/gevv_vvv_vari.hpp:5,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/stan/math/rev/core.hpp:12,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/stan/math/rev/mat.hpp:4,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/stan/math.hpp:4,
##                  from /home/rob/R/x86_64-pc-linux-gnu-library/3.4/StanHeaders/include/src/stan/model/model_header.hpp:4,
##                  from file72e39e6ad96.cpp:8:
## /home/rob/R/x86_64-pc-linux-gnu-library/3.4/BH/include/boost/config/compiler/gcc.hpp:186:0: warning: "BOOST_NO_CXX11_RVALUE_REFERENCES" redefined
##  #  define BOOST_NO_CXX11_RVALUE_REFERENCES
##  ^
## <command-line>:0:0: note: this is the location of the previous definition
# G = 2: one shared inter-equation scale each for the age and state effects.
stan_dat <- list(
  N = N, K = K, G = 2,
  J_age = J_age, J_state = J_state,
  idx_age = idx_age, idx_state = idx_state,
  y = y
)
elec_fit <- sampling(
  elec_mod,
  data = stan_dat,
  iter = 2000,
  seed = 3454542,
  chains = 4,
  cores = 4,
  # Allow deeper NUTS trees than the default for this model.
  control = list(max_treedepth = 15)
)

Examine the fit

# Posterior summaries of the scale parameters and intercepts. alpha[4]
# (Stein, the reference category) is a constant zero, hence its NaN Rhat.
print(elec_fit, pars = c('sigma_inter_eqn','alpha','sigma_age','sigma_state'))
## Inference for Stan model: mnl_constrained_multinom_lik_election.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##                     mean se_mean   sd  2.5%   25%  50%  75% 97.5% n_eff
## sigma_inter_eqn[1]  1.89    0.01 0.51  1.08  1.53 1.82 2.20  3.03  3328
## sigma_inter_eqn[2]  0.61    0.01 0.30  0.24  0.40 0.54 0.74  1.40  1559
## alpha[1]            0.02    0.02 1.00 -1.98 -0.67 0.04 0.71  1.98  4000
## alpha[2]           -0.01    0.02 1.02 -2.01 -0.72 0.00 0.69  1.95  4000
## alpha[3]            0.01    0.02 0.98 -1.83 -0.67 0.02 0.69  1.88  4000
## alpha[4]            0.00    0.00 0.00  0.00  0.00 0.00 0.00  0.00  4000
## sigma_age[1]        1.60    0.01 0.47  0.83  1.26 1.56 1.89  2.65  3575
## sigma_age[2]        1.56    0.01 0.49  0.78  1.20 1.50 1.85  2.69  3192
## sigma_age[3]        0.63    0.01 0.30  0.24  0.41 0.57 0.78  1.36  1996
## sigma_state[1]      0.67    0.01 0.29  0.23  0.45 0.62 0.84  1.34  2147
## sigma_state[2]      0.96    0.01 0.41  0.33  0.65 0.89 1.20  1.94  2141
## sigma_state[3]      0.85    0.01 0.37  0.29  0.57 0.79 1.07  1.69  2043
##                    Rhat
## sigma_inter_eqn[1]    1
## sigma_inter_eqn[2]    1
## alpha[1]              1
## alpha[2]              1
## alpha[3]              1
## alpha[4]            NaN
## sigma_age[1]          1
## sigma_age[2]          1
## sigma_age[3]          1
## sigma_state[1]        1
## sigma_state[2]        1
## sigma_state[3]        1
## 
## Samples were drawn using NUTS(diag_e) at Mon Aug 14 16:26:25 2017.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).
# Unpermuted draws keep the per-chain structure (per the rstan docs, an
# iterations x chains x parameters array), which bayesplot accepts directly.
plot_pars <- c('sigma_inter_eqn','alpha','sigma_age','sigma_state')
post_draws <- rstan::extract(
  elec_fit,
  pars = plot_pars,
  permuted = FALSE
)
mcmc_areas(post_draws)

How can we expand this model?