Supervised machine learning

We’ll be working with tweets from the four major US presidential candidates’ Twitter accounts.

Let’s train a classifier that separates tweets by Donald Trump from tweets by Ted Cruz, and use it to determine which words are most distinctively used by Trump relative to Cruz.

This example was adapted from workshop materials created by Pablo Barbera.

library(tidyverse)

# Read all candidate tweets and keep only the two accounts being compared.
tweets <- read_csv("../data/pres_tweets.csv") %>%
    subset(displayName %in% c("Donald J. Trump", "Ted Cruz"))
# Binary outcome for the classifier. NOTE: despite its name, `trump` is 0 for
# Donald J. Trump and 1 for Ted Cruz. Class 1 (Cruz) is therefore the
# "positive" class in the precision/recall computations, and negative model
# coefficients correspond to Trump-associated features.
tweets$trump <- ifelse(tweets$displayName == "Donald J. Trump", 0, 1)

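We’ll do some cleaning as well: replacing every Twitter handle (an @ followed by a username) with a bare @. Why? We want to prevent overfitting to the specific accounts that each candidate happens to mention.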

# Replace every @handle with a bare "@": the fact that a tweet mentions an
# account is kept, but which account is mentioned is discarded.
tweets$body <- gsub(pattern = "@[0-9_A-Za-z]+", replacement = "@", x = tweets$body)

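Create the document-feature matrix (dfm) and trim it so that only tokens that occur at least 10 times in total across all tweets are included. (To instead require a token to appear in at least 10 different tweets, trim by document frequency.)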

# updated for quanteda 0.9.9.50
library(quanteda)
# one document per tweet body, in the same order as the rows of `tweets`
twcorpus <- corpus(tweets$body)
# Build the document-feature matrix. It drops English stopwords and a
# hand-picked list of extra tokens, plus numbers, symbols and URLs. Rare
# features are then removed. `min_count` is a threshold on a feature's total
# count summed over all tweets; a threshold on the number of tweets containing
# the feature would be `min_docfreq`.
twdfm <- twcorpus %>%
    dfm(remove = c(stopwords("english"), "t.co", "https", "rt", "amp", "http", "t.c", "can"),
        remove_numbers = TRUE,
        remove_symbols = TRUE,
        remove_url = TRUE) %>%
    dfm_trim(min_count = 10)

And split the dataset into training and test sets. We’ll go with 80% training and 20% test. Note the use of a random seed to make sure our results are replicable.

set.seed(123)
# 80% of the row indices, drawn at random (reproducibly, given the seed
# above), form the training set. seq_len() is used instead of 1:nrow() so that
# an empty `tweets` gives empty index vectors rather than c(1, 0).
training <- sample(seq_len(nrow(tweets)), floor(0.80 * nrow(tweets)))
# every remaining row index, in ascending order, forms the test set
test <- setdiff(seq_len(nrow(tweets)), training)

Our first step is to train the classifier using cross-validation. There are many packages in R to run machine learning models. For regularized regression, glmnet is in my opinion the best. It’s much faster than caret or mlr (in my experience at least), and it has cross-validation already built-in, so we don’t need to code it from scratch.

library(glmnet)
# doMC only supplies a parallel backend for the cross-validation folds. Use it
# when it is installed; otherwise run the folds sequentially instead of failing
# (require() would merely warn, and registerDoMC() would then error).
use_parallel <- requireNamespace("doMC", quietly = TRUE)
if (use_parallel) {
    library(doMC)
    registerDoMC(cores = 3)
}
# alpha = 0 selects the ridge (L2) penalty. 5-fold cross-validation over the
# lambda path, scored by binomial deviance.
ridge <- cv.glmnet(twdfm[training,], tweets$trump[training],
    family = "binomial", alpha = 0, nfolds = 5, parallel = use_parallel,
    type.measure = "deviance")
plot(ridge)

We can now compute the performance metrics on the test set.

## helper: 2 x 2 confusion table of predictions (rows) against true labels
## (columns). Logical and 0/1 inputs become factors with both levels
## FALSE/TRUE, so a class that never occurs still gets its own all-zero row or
## column. A plain table() would silently drop it, and the fixed [2, 2]
## indexing below would then fail (precision/recall) or read the wrong cell
## (accuracy). Any other input is tabulated as-is, and its second sorted level
## is treated as the positive class.
confusion_table <- function(ypred, y){
    as_binary <- function(v){
        if (is.logical(v) || (is.numeric(v) && all(v %in% c(0, 1)))) {
            # as.logical() also drops matrix dims, e.g. from predict() output
            factor(as.logical(v), levels = c(FALSE, TRUE))
        } else {
            v
        }
    }
    table(as_binary(ypred), as_binary(y))
}
## function to compute accuracy: share of observations on the diagonal
accuracy <- function(ypred, y){
    tab <- confusion_table(ypred, y)
    sum(diag(tab)) / sum(tab)
}
# function to compute precision: TP / (TP + FP), over predicted positives
precision <- function(ypred, y){
    tab <- confusion_table(ypred, y)
    tab[2, 2] / (tab[2, 1] + tab[2, 2])
}
# function to compute recall: TP / (TP + FN), over actual positives
recall <- function(ypred, y){
    tab <- confusion_table(ypred, y)
    tab[2, 2] / (tab[1, 2] + tab[2, 2])
}
# computing predicted values: the predicted probability that each test tweet
# belongs to class 1 of tweets$trump, at lambda.1se (the default `s` of
# predict() for a cv.glmnet fit, written out explicitly here). A tweet is
# assigned to class 1 when that probability exceeds the share of class-1
# tweets in the *training* set. Taking the cut-off from the test labels would
# leak test-set information into the predictions.
preds <- predict(ridge, twdfm[test,], s = "lambda.1se", type = "response") >
    mean(tweets$trump[training])
# confusion matrix (rows: predicted, columns: true label)
table(preds, tweets$trump[test])
## NOTE(review): the printed results below were produced with the cut-off
## taken from the test labels; re-run to refresh them (values may shift).
##        
## preds      0    1
##   FALSE 1179   53
##   TRUE    95  678
# performance metrics
accuracy(preds, tweets$trump[test])
## [1] 0.9261845
precision(preds, tweets$trump[test])
## [1] 0.8771022
recall(preds, tweets$trump[test])
## [1] 0.9274966

Something that is often very useful is to look at the actual estimated coefficients and see which of these have the highest or lowest values:

# from the different values of lambda, let's pick the best one. lambda.min is
# one of the values stored in ridge$lambda (the one with the lowest mean
# cross-validated deviance), so match() returns its position in that vector.
best.lambda <- match(ridge$lambda.min, ridge$lambda)
# One coefficient per dfm feature at that lambda, named by feature.
# tweets$trump is coded 0 = Donald J. Trump and 1 = Ted Cruz, so negative
# values point towards Trump and positive values towards Cruz.
beta <- ridge$glmnet.fit$beta[, best.lambda]
head(beta)
##          @       will      trump      great      thank    america 
## -0.6568460 -0.4763454 -1.0495392 -0.9517421 -0.3899949 -0.5320853
## identifying predictive features
df <- data.frame(coef = as.numeric(beta),
                 word = names(beta), stringsAsFactors = FALSE)

# most negative coefficients first (Trump-associated features)
df <- df[order(df$coef), ]
head(df[, c("coef", "word")], n = 30)
##           coef                             word
## 812  -2.946053                    entrepreneurs
## 1119 -2.780045                          stopped
## 1291 -2.431947 #trump2016#makeamericagreatagain
## 1195 -2.422461                       #wiprimary
## 1265 -2.413748                         register
## 1289 -2.365488                          muslims
## 13   -2.244013           #makeamericagreatagain
## 1124 -2.224513                              son
## 1365 -2.211081                        louisiana
## 1099 -2.162391                        desperate
## 1215 -2.147966                          changed
## 867  -2.109633                                c
## 1363 -2.061911                            wayne
## 1216 -2.056875                             mark
## 1077 -2.012677                          pushing
## 1207 -1.988745                              hey
## 1387 -1.973182                        legendary
## 733  -1.934180                        respected
## 807  -1.933792                       #teamtrump
## 1156 -1.925063                          network
## 880  -1.923742                           macy's
## 588  -1.899867                         horrible
## 1203 -1.872746                            might
## 1052 -1.845172                     professional
## 1303 -1.832932                           handle
## 853  -1.807507                          tickets
## 1060 -1.800608                        turnberry
## 1331 -1.795529                            store
## 1093 -1.778157                        opponents
## 1162 -1.773498                          boycott
# head() rather than [1:30], so fewer than 30 features gives no NA entries
paste(head(df$word, 30), collapse = ", ")
## [1] "entrepreneurs, stopped, #trump2016#makeamericagreatagain, #wiprimary, register, muslims, #makeamericagreatagain, son, louisiana, desperate, changed, c, wayne, mark, pushing, hey, legendary, respected, #teamtrump, network, macy's, horrible, might, professional, handle, tickets, turnberry, store, opponents, boycott"
# most positive coefficients first (Cruz-associated features)
df <- df[order(df$coef, decreasing = TRUE), ]
head(df[, c("coef", "word")], n = 30)
##          coef                    word
## 932  2.833202                   admin
## 54   2.791958             #choosecruz
## 749  2.789096                   heidi
## 1378 2.740852                  prayer
## 1287 2.705626              #cruzcrowd
## 374  2.696764          #atimefortruth
## 1329 2.670115                   human
## 129  2.666027          #caucusforcruz
## 1343 2.497244                  ensure
## 979  2.492870                 houston
## 1040 2.457241               #kateslaw
## 709  2.451043                   y'all
## 916  2.443296           #believeagain
## 1188 2.423531                   catch
## 1100 2.386417    #heritageactionforum
## 274  2.370825               #irandeal
## 991  2.313215             opportunity
## 935  2.270316              #savesaeed
## 767  2.254369             #stand4life
## 312  2.241739 #defendreligiousliberty
## 892  2.232029         #sunshinesummit
## 465  2.223027                 #scotus
## 1286 2.208089               catherine
## 1147 2.205606                 uniting
## 329  2.192077             #fullrepeal
## 1392 2.177566                 donated
## 278  2.122404          #abolishtheirs
## 8    2.104048               #cruzcrew
## 1209 2.090015                  who'll
## 70   2.065730            #cruzcountry
paste(head(df$word, 30), collapse = ", ")
## [1] "admin, #choosecruz, heidi, prayer, #cruzcrowd, #atimefortruth, human, #caucusforcruz, ensure, houston, #kateslaw, y'all, #believeagain, catch, #heritageactionforum, #irandeal, opportunity, #savesaeed, #stand4life, #defendreligiousliberty, #sunshinesummit, #scotus, catherine, uniting, #fullrepeal, donated, #abolishtheirs, #cruzcrew, who'll, #cruzcountry"