library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.3     ✔ readr     2.1.4
## ✔ forcats   1.0.0     ✔ stringr   1.5.0
## ✔ ggplot2   3.4.3     ✔ tibble    3.2.1
## ✔ lubridate 1.9.2     ✔ tidyr     1.3.0
## ✔ purrr     1.0.2     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidytext)
library(tictoc)
library(word2vec)
library(textclean)
library(textstem)
## Loading required package: koRpus.lang.en
## Loading required package: koRpus
## Loading required package: sylly
## For information on available language packages for 'koRpus', run
## 
##   available.koRpus.lang()
## 
## and see ?install.koRpus.lang()
## 
## 
## Attaching package: 'koRpus'
## 
## The following object is masked from 'package:readr':
## 
##     tokenize
library(Rtsne)
library(plotly)
## 
## Attaching package: 'plotly'
## 
## The following object is masked from 'package:ggplot2':
## 
##     last_plot
## 
## The following object is masked from 'package:stats':
## 
##     filter
## 
## The following object is masked from 'package:graphics':
## 
##     layout

Introducción

En este notebook vamos a entrenar un modelo de Word embeddings con datos de Los Simpsons (guiones de varias temporadas, cerca de 150 mil lineas de dialogo para unos 600 episodios) y tratar de entender los resultados.

La idea va a ser tratar de entrenar un modelo word2vec sobre la base de esos diálogos e inferir algunas relaciones solamente con los textos.

Detectando relaciones con Los Simpsons

dialogos <- read_csv('../data/simpsons_script_lines.csv') %>%
        select(id, episode_id, raw_character_text, spoken_words) %>%
        arrange(episode_id, id)
## Warning: One or more parsing issues, call `problems()` on your data frame for details,
## e.g.:
##   dat <- vroom(...)
##   problems(dat)
## Rows: 158271 Columns: 13
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (5): raw_text, raw_character_text, raw_location_text, spoken_words, norm...
## dbl (7): id, episode_id, number, timestamp_in_ms, character_id, location_id,...
## lgl (1): speaking_line
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
head(dialogos)
## # A tibble: 6 × 4
##      id episode_id raw_character_text spoken_words                  
##   <dbl>      <dbl> <chr>              <chr>                         
## 1     1          1 <NA>               <NA>                          
## 2     2          1 <NA>               <NA>                          
## 3     3          1 Marge Simpson      Ooo, careful, Homer.          
## 4     4          1 Homer Simpson      There's no time to be careful.
## 5     5          1 Homer Simpson      We're late.                   
## 6     6          1 <NA>               <NA>

Hay valores faltantes en la primera columna que tienen que ver con dialogo que no es provisto por ninguno de los personajes, por ejemplo, voces en off, etc. Podemos removerlo como sigue:

dialogos <- dialogos %>%
                drop_na()

Ahora, vamos a limpiar el texto. Pero hay una diferencia: en el idioma inglés, el símbolo ’ se usa para contracciones, y si lo removemos sin expandir las contracciones perdemos informacion. Por lo tanto, antes vamos a lidiar con las contracciones usando una función del paquete textclean:

dialogos <- dialogos %>%
        mutate(spoken_words=replace_contraction(spoken_words))

head(dialogos$spoken_words)
## [1] "Ooo, careful, Homer."                                                                                             
## [2] "there is no time to be careful."                                                                                  
## [3] "we are late."                                                                                                     
## [4] "Sorry, Excuse us. Pardon me..."                                                                                   
## [5] "Hey, Norman. how is it going? So you got dragged down here, too... heh, heh. How ya doing, Fred? Excuse me, Fred."
## [6] "Pardon my galoshes."

Ahora si, podemos remover la puntuacion, pasar todo a minúscula, reempalzar dígitos por espacios y transformar caracteres que no sean ascii.

dialogos <- dialogos %>%
        mutate(spoken_words = str_replace_all(spoken_words, "'\\[.*?¿\\]\\%'", " ")) %>%
        mutate(spoken_words = str_replace_all(spoken_words, "[[:punct:]]", " ")) %>%
        mutate(spoken_words = tolower(spoken_words)) %>%
        mutate(spoken_words = str_replace_all(spoken_words, "[[:digit:]]+", "")) %>%
        mutate(spoken_words = replace_non_ascii(spoken_words))

head(dialogos)
## # A tibble: 6 × 4
##      id episode_id raw_character_text spoken_words                              
##   <dbl>      <dbl> <chr>              <chr>                                     
## 1     3          1 Marge Simpson      ooo careful homer                         
## 2     4          1 Homer Simpson      there is no time to be careful            
## 3     5          1 Homer Simpson      we are late                               
## 4     8          1 Marge Simpson      sorry excuse us pardon me                 
## 5     9          1 Homer Simpson      hey norman how is it going so you got dra…
## 6    10          1 Homer Simpson      pardon my galoshes

A continuación, podemos eliminar stopwords. Para ello, vamos a pasar el dataset a formato tidy, luego eliminamos los stopwords con un antijoin y finalmente volvemos a rearmar la columna sin los stopwords:

# Tokenizamos
unigramas <- dialogos %>%
                unnest_tokens(word, spoken_words)

# Cargamos una lista de stowords en inglés
data(stop_words)

# Agregamos algunas stopwords ad-hoc
stop_words <- stop_words %>%
                add_row(word=c('hey','ho'), lexicon=c('adhoc','adhoc'))


# stop_words <- stop_words %>%
#                 mutate(word = lemmatize_words(word))

# Eliminamos stopwords
unigramas <- unigramas %>%
                anti_join(stop_words)
## Joining with `by = join_by(word)`
# Rearmamos la columna
unigramas <- unigramas %>%
        group_by(id, episode_id, raw_character_text) %>%
        summarize(text = str_c(word, collapse = " ")) %>%
        ungroup()
## `summarise()` has grouped output by 'id', 'episode_id'. You can override using
## the `.groups` argument.

Lematizamos:

unigramas <- unigramas %>%
        mutate(text = lemmatize_strings(unigramas$text))

Volvemos a tokenizar pero ahora usando uni y bigramas:

uni_bigramas <- unigramas %>%
                        unnest_ngrams(bigram, text, n_min=1, n=2)

Unimos los bigramas con “_“:

uni_bigramas <- uni_bigramas %>%
        mutate(text = str_replace_all(bigram, " ", "_"))

Y volvemos a rearmar la línea del diálogo.

uni_bigramas <- uni_bigramas %>%
        group_by(id, episode_id, raw_character_text) %>%
        summarize(text = str_c(text, collapse = " ")) %>%
        ungroup()
## `summarise()` has grouped output by 'id', 'episode_id'. You can override using
## the `.groups` argument.
uni_bigramas
## # A tibble: 121,476 × 4
##       id episode_id raw_character_text text                                     
##    <dbl>      <dbl> <chr>              <chr>                                    
##  1     3          1 Marge Simpson      ooo ooo_careful careful careful_homer ho…
##  2     4          1 Homer Simpson      time time_careful careful                
##  3     5          1 Homer Simpson      late                                     
##  4     8          1 Marge Simpson      excuse excuse_pardon pardon              
##  5     9          1 Homer Simpson      norman norman_drag drag drag_heh heh heh…
##  6    10          1 Homer Simpson      pardon pardon_galoshes galoshes          
##  7    11          1 Seymour Skinner    wonderful wonderful_santas santas santas…
##  8    12          1 Marge Simpson      lisa lisa_class class                    
##  9    13          1 JANEY              frohlich frohlich_weihnachten weihnachte…
## 10    14          1 Todd Flanders      meri meri_kurimasu kurimasu kurimasu_hot…
## # ℹ 121,466 more rows

Bueno, el siguiente paso consta de entrenar el modelo word2vec con los datos. Para eso, le pasamos una serie de parametros que estan explicados en el codigo siguiente.

# word2vec_s <- word2vec(x=uni_bigramas$text, # Pasamos la columna con texto
#                        type='skip-gram', #Elegimos el método de ventana
#                        hs=FALSE,
#                        min_count=20, # Ignora palabras cuya frecuencia es menor a esta
#                       window=2, # Fijamos el tamaño de la ventana de contexto
#                       dim=300, # Definimos en cuántas dimensiones queremos el embedding
#                       sample=0.00006, # Umbral para downsamplear palabras muy frecuentes
#                       lr=0.005, #  Tasa de aprendizaje inicial (param. de la red neuronal)
#                       negative=20, # penalidad de palabras poco informaitvas
#                       iter=50, # Iteraciones del modelo
#                       split=c(" \n,.-!?:;/\"#$%&'()*+<=>@[]\\^`{|}~\t\v\f\r",
#                               ".\n?!")
#                       )
# write.word2vec(word2vec_s, '../models/w2v_uni_bigrams3.bin')         

Como el tiempo es tirano, hacemos la gran “Narda Lepes” y sacamos un modelo ya pre-entrenado para abreviar un poco:

word2vec_s <- read.word2vec('../models/w2v_uni_bigrams3.bin')

Detectando relaciones a partir del texto

Ya tenemos todas las palabras en los guiones mapeadas a un espacio de dimension 300. Podemos calcular la similitud semantica entre estas palabras usando la distancia coseno entre ellas.

Esto nos permite hacer queries como por ejemplo, buscar las palabras mas cercanas a alguna en particular.

## $moe
##    term1       term2 similarity rank
## 1    moe     moe_moe  0.9902663    1
## 2    moe  moe_tavern  0.9651136    2
## 3    moe      tavern  0.9573137    3
## 4    moe moe_szyslak  0.9523423    4
## 5    moe     szyslak  0.9511768    5
## 6    moe       flame  0.9373197    6
## 7    moe   bartender  0.8892170    7
## 8    moe         bar  0.8763051    8
## 9    moe       gotta  0.8746590    9
## 10   moe       drink  0.8744557   10

predict(word2vec_s, newdata = c("burn"), type = "nearest", top_n = 30)
## $burn
##    term1           term2 similarity rank
## 1   burn          annual  0.9695159    1
## 2   burn      montgomery  0.9637857    2
## 3   burn           award  0.9609802    3
## 4   burn         charles  0.9532237    4
## 5   burn     outstanding  0.9370724    5
## 6   burn           owner  0.9282656    6
## 7   burn      monty_burn  0.9263046    7
## 8   burn          arrest  0.9227421    8
## 9   burn           monty  0.9181376    9
## 10  burn         speaker  0.9161352   10
## 11  burn           local  0.9159558   11
## 12  burn         uh_burn  0.9145626   12
## 13  burn            germ  0.9121226   13
## 14  burn           court  0.9116307   14
## 15  burn        governor  0.9114665   15
## 16  burn         citizen  0.9111903   16
## 17  burn        champion  0.9108621   17
## 18  burn          estate  0.9099905   18
## 19  burn     billionaire  0.9084716   19
## 20  burn montgomery_burn  0.9083023   20
## 21  burn     millionaire  0.9081165   21
## 22  burn         you_you  0.9077687   22
## 23  burn          behold  0.9072650   23
## 24  burn          kidnap  0.9063010   24
## 25  burn          source  0.9058962   25
## 26  burn       challenge  0.9058764   26
## 27  burn            view  0.9056534   27
## 28  burn           enemy  0.9050730   28
## 29  burn           grand  0.9047724   29
## 30  burn       expensive  0.9046906   30

También podemos jugar con las clásicas analogías: ¿Qué es al reverendo Alegría lo que Marge es a Homero?

wv <- predict(word2vec_s, newdata = c("marge", "reverend_lovejoy", "homer"), type = "embedding")
wv <- wv["marge", ] - wv["homer", ] + wv["reverend_lovejoy", ]

predict(word2vec_s, newdata = wv, type = "nearest", top_n = 3)
##      term similarity rank
## 1 lovejoy  0.9980072    1
## 2   helen  0.9908397    2
## 3   rabbi  0.8993034    3

¿Qué es a Flanders lo que Marge es a Homero?

wv <- predict(word2vec_s, newdata = c("marge", "flanders", "homer"), type = "embedding")
wv <- wv["marge", ] - wv["homer", ] + wv["flanders", ]
predict(word2vec_s, newdata = wv, type = "nearest", top_n = 3)
##           term similarity rank
## 1        maude  0.9988592    1
## 2 ned_flanders  0.9886520    2
## 3         oooh  0.9885184    3

¿Qué es a “mujer” lo que “Homero” es a “muchacho”? O lo mismo pero con “Bart”

wv <- predict(word2vec_s, newdata = c("woman", "homer", "guy"), type = "embedding")
wv <- wv["homer", ] - wv["guy", ] + wv["woman", ]
predict(word2vec_s, newdata = wv, type = "nearest", top_n = 3)
##            term similarity rank
## 1 marge_simpson  0.9991993    1
## 2   simpson_sir  0.9969260    2
## 3   homer_marge  0.9890501    3
wv <- predict(word2vec_s, newdata = c("woman", "bart_simpson", "guy"), type = "embedding")
wv <- wv["bart_simpson", ] - wv["guy", ] + wv["woman", ]
predict(word2vec_s, newdata = wv, type = "nearest", top_n = 3)
##           term similarity rank
## 1    bart_lisa  0.9940181    1
## 2 lisa_simpson  0.9886613    2
## 3      simpson  0.9851313    3

Ahora vamos a visualizar el embedding usando TSNE, una tecnica no lineal de reducción de dimensionalidad. Primero, transformamos en un formato manejable el embedding y lo reproyectamos en dos ejes con TSNE:

word2vec_s_tidy <- word2vec_s %>% 
        as.matrix() %>%
        as_tibble(rownames = "word")
## Warning: The `x` argument of `as_tibble.matrix()` must have unique column names if
## `.name_repair` is omitted as of tibble 2.0.0.
## ℹ Using compatibility `.name_repair`.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
tictoc::tic()
tsne_w2v <- Rtsne(word2vec_s_tidy %>% select(-word),
                  theta=0.1,
                  pca_scale=TRUE,
                  )
tictoc::toc()
## 103.12 sec elapsed
 pl<-tsne_w2v$Y %>%
                as_tibble() %>%
                bind_cols(word2vec_s_tidy %>% select(word)) %>%
        ggplot() + 
                geom_text(aes(x=V1, y=V2, label=word), size=2) +
                theme_minimal()

pl

No pareciera verse demasiado claro… veamos si podemos hacer algo al respecto usando ggplotly:

ggplotly(pl)

Ahora vamos a armar una funcion que plotea una palabra (target), las palabras más cercanas (near), las más lejanas (far) y una muestra aleatoria (random).

plot_wv <- function(cw, w2v_matrix=word2vec_s_tidy, tsne_matrix=tsne_w2v, n_words=8){
        dist <- predict(word2vec_s, 
                             newdata = cw, 
                             type = "nearest", 
                             top_n = nrow(word2vec_s_tidy))[[1]] %>%
                as_tibble() %>%
                janitor::clean_names()%>%
                arrange(desc(similarity))
        
        nearest <- c(cw, 
             dist %>%
                     head(n_words) %>%
                     select("term2") %>% pull()
             )
        
        
        farest <- c( 
             dist %>%
                     tail(n_words) %>%
                     select("term2") %>% pull()
             )
        
        random <- sample_n(word2vec_s_tidy %>% select(word), size = n_words) %>% pull()

        filters <- c(nearest, farest, random)
        
        tsne_w2v$Y %>%
                as_tibble() %>%
                bind_cols(word2vec_s_tidy %>% select(word)) %>%
                filter(word %in% filters) %>%
                mutate(type = case_when(
                        word == cw ~ '0_target',
                        word %in% nearest ~ '1_near',
                        word %in% farest ~'3_far',
                        word %in% random ~ '2_random'
        )) %>%
        ggplot() + 
                geom_text(aes(x=V1, y=V2, label=word, color=type), size=3) +
                theme_minimal() +
                scale_color_viridis_d(direction=-1)
}


ggplotly(plot_wv("moe"))
ggplotly(plot_wv("marge"))