Cómo crear un árbol de decisión en R desde cero.

En este artículo, vamos a crear un árbol de decisión para predecir la supervivencia de los pasajeros en función de tres variables: age, sex y Pclass (passenger class). Las variables usadas en el árbol son relevantes, si bien no son el modelo perfecto.

Ante todo, queremos un modelo simple que nos aporte compresión sobre el problema y nos permita crear modelos más avanzados.

CARGA DE LIBRERÍAS y DATOS

Cargamos las librerias y los datos de entrenamiento

library(tidyverse)
library(caTools)

train <- read.csv("train.csv")
test  <- read.csv("test.csv")

head(train)
##   PassengerId Survived Pclass
## 1           1        0      3
## 2           2        1      1
## 3           3        1      3
## 4           4        1      1
## 5           5        0      3
## 6           6        0      3
##                                                  Name Sex Age SibSp Parch
## 1                             Braund, Mr. Owen Harris   1  22     1     0
## 2 Cumings, Mrs. John Bradley (Florence Briggs Thayer)   0  38     1     0
## 3                              Heikkinen, Miss. Laina   0  26     0     0
## 4        Futrelle, Mrs. Jacques Heath (Lily May Peel)   0  35     1     0
## 5                            Allen, Mr. William Henry   1  35     0     0
## 6                                    Moran, Mr. James   1  NA     0     0
##             Ticket    Fare Cabin Embarked
## 1        A/5 21171  7.2500              S
## 2         PC 17599 71.2833   C85        C
## 3 STON/O2. 3101282  7.9250              S
## 4           113803 53.1000  C123        S
## 5           373450  8.0500              S
## 6           330877  8.4583              Q
head(test)
##   PassengerId Pclass                                         Name    Sex  Age
## 1         892      3                             Kelly, Mr. James   male 34.5
## 2         893      3             Wilkes, Mrs. James (Ellen Needs) female 47.0
## 3         894      2                    Myles, Mr. Thomas Francis   male 62.0
## 4         895      3                             Wirz, Mr. Albert   male 27.0
## 5         896      3 Hirvonen, Mrs. Alexander (Helga E Lindqvist) female 22.0
## 6         897      3                   Svensson, Mr. Johan Cervin   male 14.0
##   SibSp Parch  Ticket    Fare Cabin Embarked
## 1     0     0  330911  7.8292              Q
## 2     1     0  363272  7.0000              S
## 3     0     0  240276  9.6875              Q
## 4     0     0  315154  8.6625              S
## 5     1     1 3101298 12.2875              S
## 6     0     0    7538  9.2250              S

Librerías propias para trabajar con árboles de decisión

library(rpart) #cálculos
library(rattle)
library(rpart.plot)

PRIMER VISTAZO A LOS DATOS

Antes de lanzarnos a crear un árbol, hay que asegurarse de que los datos están correctamente. Es decir, que no tenemos valores faltantes o outliers que puedan estropearnos todo.

Nos limitaremos a ver las variables que usaremos en el árbol.

Exámen preliminar de la información en la tabla

Lo primero es hacer factors todas las variables categoricas. De otra forma, R se empeña en sacarnos medias y medianas y no tiene sentido (así nos mostrará porcentajes y se entiende mejor)

#para los datos de entrenamiento
train$Survived<-as.factor(train$Survived)
train$Pclass<-as.factor(train$Pclass)
train$Sex<-as.factor(train$Sex)

#para los datos de test
#test$<-as.factor(test$Survived) #esta no existe
test$Pclass<-as.factor(test$Pclass)
#test$Sex<-as.factor(test$Sex) #esto hay que hacerlo mas adelante

Ahora sí, si hacemos summary nos saldrá algo comprensible:

summary(train)
##   PassengerId    Survived Pclass      Name           Sex          Age       
##  Min.   :  1.0   0:549    1:216   Length:891         0:314   Min.   : 0.42  
##  1st Qu.:223.5   1:342    2:184   Class :character   1:577   1st Qu.:20.12  
##  Median :446.0            3:491   Mode  :character           Median :28.00  
##  Mean   :446.0                                               Mean   :29.70  
##  3rd Qu.:668.5                                               3rd Qu.:38.00  
##  Max.   :891.0                                               Max.   :80.00  
##                                                              NA's   :177    
##      SibSp           Parch           Ticket               Fare       
##  Min.   :0.000   Min.   :0.0000   Length:891         Min.   :  0.00  
##  1st Qu.:0.000   1st Qu.:0.0000   Class :character   1st Qu.:  7.91  
##  Median :0.000   Median :0.0000   Mode  :character   Median : 14.45  
##  Mean   :0.523   Mean   :0.3816                      Mean   : 32.20  
##  3rd Qu.:1.000   3rd Qu.:0.0000                      3rd Qu.: 31.00  
##  Max.   :8.000   Max.   :6.0000                      Max.   :512.33  
##                                                                      
##     Cabin             Embarked        
##  Length:891         Length:891        
##  Class :character   Class :character  
##  Mode  :character   Mode  :character  
##                                       
##                                       
##                                       
## 
summary(test)
##   PassengerId     Pclass      Name               Sex                 Age       
##  Min.   : 892.0   1:107   Length:418         Length:418         Min.   : 0.17  
##  1st Qu.: 996.2   2: 93   Class :character   Class :character   1st Qu.:21.00  
##  Median :1100.5   3:218   Mode  :character   Mode  :character   Median :27.00  
##  Mean   :1100.5                                                 Mean   :30.27  
##  3rd Qu.:1204.8                                                 3rd Qu.:39.00  
##  Max.   :1309.0                                                 Max.   :76.00  
##                                                                 NA's   :86     
##      SibSp            Parch           Ticket               Fare        
##  Min.   :0.0000   Min.   :0.0000   Length:418         Min.   :  0.000  
##  1st Qu.:0.0000   1st Qu.:0.0000   Class :character   1st Qu.:  7.896  
##  Median :0.0000   Median :0.0000   Mode  :character   Median : 14.454  
##  Mean   :0.4474   Mean   :0.3923                      Mean   : 35.627  
##  3rd Qu.:1.0000   3rd Qu.:0.0000                      3rd Qu.: 31.500  
##  Max.   :8.0000   Max.   :9.0000                      Max.   :512.329  
##                                                       NA's   :1        
##     Cabin             Embarked        
##  Length:418         Length:418        
##  Class :character   Class :character  
##  Mode  :character   Mode  :character  
##                                       
##                                       
##                                       
## 

Está claro que tenemos un problema con la variable Age, que tiene 177 +1 NA. Además, la variable sex toma valores male y female en los datos de test y hay que convertirlos previamente.

PRIMERAS CORRECCIONES A LOS DATOS

Male y female en la tabla de test

Como comentábamos antes, Sex toma valores female,male en lugar de 0,1. Tenemos que convertirlos adecuadamente. Male=1, Female=0.

for(i in 1:nrow(test))
    {
    if (test$Sex[i]=='male') test$Sex[i]<-1  
    if (test$Sex[i]=='female') test$Sex[i]<-0  
}

#ahora si lo convertimos a variable categorica. Si lo hacemos antes da error aqui.
test$Sex<-as.factor(test$Sex)

IMPUTANDO VALORES A AGE

Para imputar valores necesitaremos sacar informacion de los que ya están informados. Tenemos principalmente tres opciones para imputar Age:

  • Imputar a la media o la mediana, según los datos sean o no sesgados.
  • Usar machine learning para extraer age del resto de las variables e imputar usando el resultado.
  • Ver con qué variables varía age, dividir en grupos e imputar a la media o mediana de cada uno de esos grupos.

Usaremos la primera técnica, por ser la más sencilla, aunque no necesariamente la mejor. El borrado de datos, dada la elevada cantidad de valores NA, ni siquiera nos lo podemos plantear.

AGE_POR_IMPUTAR_TRAIN<-which(is.na(train$Age))
AGE_POR_IMPUTAR_TEST<-which(is.na(test$Age))

Tenemos 263 valores de Age para imputar, más o menos el 20% del total. Una cantidad más que considerable. Notar que sería un error gravísimo borrar las filas que no tienen Age informado, dado que perderíamos mucha información.

Imputación de Age por la mediana

La variable age es sesgada a derecha (véase el gráfico debajo), por lo que, como primera aproximacion, podríamos imputar a la mediana. En general:

Distribuciones sesgadas -> imputación a la mediana Distribuciones simétricas -> imputación a la media

mediana_age<-median(c(train$Age,test$Age),na.rm = TRUE) #mediana = 28
mediana_age
## [1] 28
hist(train$Age)

Cargando la variable Age calculada

Y metemos los datos

for (i in 1:length(AGE_POR_IMPUTAR_TRAIN))
    train$Age[AGE_POR_IMPUTAR_TRAIN[i]]<-mediana_age

for (i in 1:length(AGE_POR_IMPUTAR_TEST))
    test$Age[AGE_POR_IMPUTAR_TEST[i]]<-mediana_age

Comprobemos que no ha quedado ningún NA

which(is.na(train$Age))
## integer(0)
which(is.na(test$Age))
## integer(0)

La imputación de la edad ha funcionado. Ahora ya tanto la tabla train como test tienen age imputado.

summary(train$Age)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##    0.42   22.00   28.00   29.36   35.00   80.00
summary(test$Age)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##    0.17   23.00   28.00   29.81   35.75   76.00

Dividir los datos de entrenamiento en dos partes

Para poder validar nuestro modelo, necesitaremos dividir los datos de entrenamiento en dos bloques. El propio que usaremos para calcular el árbol de decisión y otro que usaremos para la validación del árbol.

#make this example reproducible
set.seed(123)

#usamos el 70% de los datos para crear el modelo y el 30% para su validación.
sample <- sample.split(train$Age, SplitRatio = 0.7)
trainArbol  <- subset(train, sample == TRUE)
trainValida   <- subset(train, sample == FALSE)
indices <- which(!sample) #guardamos indices datos validacion

Creación del árbol de decisión con los datos de entrenamiento

En R la creación de un árbol de decisión es tan simple como esto:

#crear el arbol con la particion de entrenamiento
arbol<- rpart(
              formula = Survived ~ Age + Sex + Pclass, 
              data=trainArbol, 
              method='class')    #class es para que calcule 0/1

Una vez tenemos creado el modelo, podemos hacer summary(arbol) pero es más sencillo visualizarlo:

fancyRpartPlot(arbol,main="SUPERVIVENCIA TITANIC",sub="")

Es interesante ver la importancia de cada variable. Del propio gráfico del árbol, vemos que la más importante va a ser el sexo, seguido a la par por las otras dos:

arbol$variable.importance
##      Sex      Age   Pclass 
## 88.08892 27.75265 21.36064

El árbol es un modelo de caja blanca, comprensible al humano, aunque inexacto. No suele usarse para grandes problemas en una versión tan básic aunque tiene la gran ventaja de que, al ser legible, puede presentarse y explicarse.

Prediccion para los datos de validación

Vamos a evaluar la precisión de nuestro árbol. La forma habitual es fijarse en métricas como sensibilidad, especificidad y nivel de precisión:

library(caret)
predictions<-predict(arbol, trainValida, type = 'class')
confusionMatrix(data=predictions, reference=trainValida$Survived)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0 140  22
##          1  36  70
##                                           
##                Accuracy : 0.7836          
##                  95% CI : (0.7294, 0.8314)
##     No Information Rate : 0.6567          
##     P-Value [Acc > NIR] : 3.959e-06       
##                                           
##                   Kappa : 0.5368          
##                                           
##  Mcnemar's Test P-Value : 0.08783         
##                                           
##             Sensitivity : 0.7955          
##             Specificity : 0.7609          
##          Pos Pred Value : 0.8642          
##          Neg Pred Value : 0.6604          
##              Prevalence : 0.6567          
##          Detection Rate : 0.5224          
##    Detection Prevalence : 0.6045          
##       Balanced Accuracy : 0.7782          
##                                           
##        'Positive' Class : 0               
## 

El resultado no es malo. Sería un error intentar predecir sobre los datos de entrenamiento con mayor precisión de una forma artificial (añadiendo variables que no son buenos predictores, por ejemplo).

De hecho, el problema de los árboles de decisión es que ajustan muy bien a los datos de entrenamiento pero luego, al ser tan acotados, fallan mucho en los de test (“overfitting”).

PREDICCIONES DE SUPERVIVENCIA CON LOS DATOS DE TEST

Veamos quién vive y quién muere según nuestro modelo:

predictions_test<-predict(arbol,test,type='class')
head(predictions_test,20)
##  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 
##  0  0  0  0  1  0  1  0  1  0  0  0  1  0  1  1  0  0  1  0 
## Levels: 0 1

Podemos calcular también la supervivencia de personas concretas, dados los parámetros sexo, edad y pclass para ellas. Por ejemplo, pongamos una edad de 40, sexo = hombre y Pclass=3. Veamos como el modelo predice lo que todos suponemos que va a pasar. Es decir, que muere:

predict(object = arbol, newdata= data.frame(Age=40, Sex=as.factor(1),Pclass=as.factor(3)))
##           0         1
## 1 0.8250653 0.1749347

La cosa está clara. Este pobre desgraciado tiene todos los números para acabar congelado en medio del mar. Sin embargo, si es una niña de primera clase la cosa cambia:

predict(object = arbol, newdata= data.frame(Age=3, Sex=as.factor(0),Pclass=as.factor(1)))
##            0        1
## 1 0.02608696 0.973913

RESUMEN

En este ejercicio, hemos creado el árbol de decisión para los datos del Titanic. Hemos tenido en cuenta sólo tres variables, para simplificar el modelo y evitar los problemas de ‘overfitting’.

El resultado ha sido un árbol de tres niveles que nos ofrece una precisión del 78%. Pero, más importante que eso, hemos visto que la variable que más información nos ofrece es el sexo, siendo casi determinante en nuestro modelo.

El modelo creado mediante un árbol de decisión, si bien es impreciso, tiene la ventaja de ofrecer una gran claridad en la selección de variables y la posibilidad de ser visualizado en forma sencilla. Este conocimiento aportado nos servirá de base para crear otros modelos más sofisticados en busca de una mayor precisión.

Finalmente, notar que siempre que se crea un modelo de machine learning es necesario separar los datos de entrenamiento en dos particiones, una para calcular el modelo y la otra para validarlo. Esto se llama método de validación hold-out, habiendo otras alternativas, la más conocida de ellas k-fold.

Comentarios

Entradas populares de este blog

UNA BREVE EXPLICACIÓN DE LA LEY DE LAS ESPERANZAS ITERADAS

UN EJEMPLO DETALLADO DE PRE-PROCESAMIENTO EN R - IMPUTACIÓN DE DATOS FALTANTES

Un mundo de sucesos imposibles - El "tongo" de la Bonoloto.