Cómo crear un árbol de decisión en R desde cero.
En este artículo, vamos a crear un árbol de decisión para predecir la supervivencia de los pasajeros en función de tres variables: age, sex y Pclass (passenger class). Las variables usadas en el árbol son relevantes, si bien no son el modelo perfecto.
Ante todo, queremos un modelo simple que nos aporte compresión sobre el problema y nos permita crear modelos más avanzados.
CARGA DE LIBRERÍAS y DATOS
Cargamos las librerias y los datos de entrenamiento
library(tidyverse)
library(caTools)
train <- read.csv("train.csv")
test <- read.csv("test.csv")
head(train)
## PassengerId Survived Pclass
## 1 1 0 3
## 2 2 1 1
## 3 3 1 3
## 4 4 1 1
## 5 5 0 3
## 6 6 0 3
## Name Sex Age SibSp Parch
## 1 Braund, Mr. Owen Harris 1 22 1 0
## 2 Cumings, Mrs. John Bradley (Florence Briggs Thayer) 0 38 1 0
## 3 Heikkinen, Miss. Laina 0 26 0 0
## 4 Futrelle, Mrs. Jacques Heath (Lily May Peel) 0 35 1 0
## 5 Allen, Mr. William Henry 1 35 0 0
## 6 Moran, Mr. James 1 NA 0 0
## Ticket Fare Cabin Embarked
## 1 A/5 21171 7.2500 S
## 2 PC 17599 71.2833 C85 C
## 3 STON/O2. 3101282 7.9250 S
## 4 113803 53.1000 C123 S
## 5 373450 8.0500 S
## 6 330877 8.4583 Q
head(test)
## PassengerId Pclass Name Sex Age
## 1 892 3 Kelly, Mr. James male 34.5
## 2 893 3 Wilkes, Mrs. James (Ellen Needs) female 47.0
## 3 894 2 Myles, Mr. Thomas Francis male 62.0
## 4 895 3 Wirz, Mr. Albert male 27.0
## 5 896 3 Hirvonen, Mrs. Alexander (Helga E Lindqvist) female 22.0
## 6 897 3 Svensson, Mr. Johan Cervin male 14.0
## SibSp Parch Ticket Fare Cabin Embarked
## 1 0 0 330911 7.8292 Q
## 2 1 0 363272 7.0000 S
## 3 0 0 240276 9.6875 Q
## 4 0 0 315154 8.6625 S
## 5 1 1 3101298 12.2875 S
## 6 0 0 7538 9.2250 S
Librerías propias para trabajar con árboles de decisión
library(rpart) #cálculos
library(rattle)
library(rpart.plot)
PRIMER VISTAZO A LOS DATOS
Antes de lanzarnos a crear un árbol, hay que asegurarse de que los datos están correctamente. Es decir, que no tenemos valores faltantes o outliers que puedan estropearnos todo.
Nos limitaremos a ver las variables que usaremos en el árbol.
Exámen preliminar de la información en la tabla
Lo primero es hacer factors todas las variables categoricas. De otra forma, R se empeña en sacarnos medias y medianas y no tiene sentido (así nos mostrará porcentajes y se entiende mejor)
#para los datos de entrenamiento
train$Survived<-as.factor(train$Survived)
train$Pclass<-as.factor(train$Pclass)
train$Sex<-as.factor(train$Sex)
#para los datos de test
#test$<-as.factor(test$Survived) #esta no existe
test$Pclass<-as.factor(test$Pclass)
#test$Sex<-as.factor(test$Sex) #esto hay que hacerlo mas adelante
Ahora sí, si hacemos summary nos saldrá algo comprensible:
summary(train)
## PassengerId Survived Pclass Name Sex Age
## Min. : 1.0 0:549 1:216 Length:891 0:314 Min. : 0.42
## 1st Qu.:223.5 1:342 2:184 Class :character 1:577 1st Qu.:20.12
## Median :446.0 3:491 Mode :character Median :28.00
## Mean :446.0 Mean :29.70
## 3rd Qu.:668.5 3rd Qu.:38.00
## Max. :891.0 Max. :80.00
## NA's :177
## SibSp Parch Ticket Fare
## Min. :0.000 Min. :0.0000 Length:891 Min. : 0.00
## 1st Qu.:0.000 1st Qu.:0.0000 Class :character 1st Qu.: 7.91
## Median :0.000 Median :0.0000 Mode :character Median : 14.45
## Mean :0.523 Mean :0.3816 Mean : 32.20
## 3rd Qu.:1.000 3rd Qu.:0.0000 3rd Qu.: 31.00
## Max. :8.000 Max. :6.0000 Max. :512.33
##
## Cabin Embarked
## Length:891 Length:891
## Class :character Class :character
## Mode :character Mode :character
##
##
##
##
summary(test)
## PassengerId Pclass Name Sex Age
## Min. : 892.0 1:107 Length:418 Length:418 Min. : 0.17
## 1st Qu.: 996.2 2: 93 Class :character Class :character 1st Qu.:21.00
## Median :1100.5 3:218 Mode :character Mode :character Median :27.00
## Mean :1100.5 Mean :30.27
## 3rd Qu.:1204.8 3rd Qu.:39.00
## Max. :1309.0 Max. :76.00
## NA's :86
## SibSp Parch Ticket Fare
## Min. :0.0000 Min. :0.0000 Length:418 Min. : 0.000
## 1st Qu.:0.0000 1st Qu.:0.0000 Class :character 1st Qu.: 7.896
## Median :0.0000 Median :0.0000 Mode :character Median : 14.454
## Mean :0.4474 Mean :0.3923 Mean : 35.627
## 3rd Qu.:1.0000 3rd Qu.:0.0000 3rd Qu.: 31.500
## Max. :8.0000 Max. :9.0000 Max. :512.329
## NA's :1
## Cabin Embarked
## Length:418 Length:418
## Class :character Class :character
## Mode :character Mode :character
##
##
##
##
Está claro que tenemos un problema con la variable Age, que tiene 177 +1 NA. Además, la variable sex toma valores male y female en los datos de test y hay que convertirlos previamente.
PRIMERAS CORRECCIONES A LOS DATOS
Male y female en la tabla de test
Como comentábamos antes, Sex toma valores female,male en lugar de 0,1. Tenemos que convertirlos adecuadamente. Male=1, Female=0.
for(i in 1:nrow(test))
{
if (test$Sex[i]=='male') test$Sex[i]<-1
if (test$Sex[i]=='female') test$Sex[i]<-0
}
#ahora si lo convertimos a variable categorica. Si lo hacemos antes da error aqui.
test$Sex<-as.factor(test$Sex)
IMPUTANDO VALORES A AGE
Para imputar valores necesitaremos sacar informacion de los que ya están informados. Tenemos principalmente tres opciones para imputar Age:
- Imputar a la media o la mediana, según los datos sean o no sesgados.
- Usar machine learning para extraer age del resto de las variables e imputar usando el resultado.
- Ver con qué variables varía age, dividir en grupos e imputar a la media o mediana de cada uno de esos grupos.
Usaremos la primera técnica, por ser la más sencilla, aunque no necesariamente la mejor. El borrado de datos, dada la elevada cantidad de valores NA, ni siquiera nos lo podemos plantear.
AGE_POR_IMPUTAR_TRAIN<-which(is.na(train$Age))
AGE_POR_IMPUTAR_TEST<-which(is.na(test$Age))
Tenemos 263 valores de Age para imputar, más o menos el 20% del total. Una cantidad más que considerable. Notar que sería un error gravísimo borrar las filas que no tienen Age informado, dado que perderíamos mucha información.
Imputación de Age por la mediana
La variable age es sesgada a derecha (véase el gráfico debajo), por lo que, como primera aproximacion, podríamos imputar a la mediana. En general:
Distribuciones sesgadas -> imputación a la mediana Distribuciones simétricas -> imputación a la media
mediana_age<-median(c(train$Age,test$Age),na.rm = TRUE) #mediana = 28
mediana_age
## [1] 28
hist(train$Age)
Cargando la variable Age calculada
Y metemos los datos
for (i in 1:length(AGE_POR_IMPUTAR_TRAIN))
train$Age[AGE_POR_IMPUTAR_TRAIN[i]]<-mediana_age
for (i in 1:length(AGE_POR_IMPUTAR_TEST))
test$Age[AGE_POR_IMPUTAR_TEST[i]]<-mediana_age
Comprobemos que no ha quedado ningún NA
which(is.na(train$Age))
## integer(0)
which(is.na(test$Age))
## integer(0)
La imputación de la edad ha funcionado. Ahora ya tanto la tabla train como test tienen age imputado.
summary(train$Age)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.42 22.00 28.00 29.36 35.00 80.00
summary(test$Age)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.17 23.00 28.00 29.81 35.75 76.00
Dividir los datos de entrenamiento en dos partes
Para poder validar nuestro modelo, necesitaremos dividir los datos de entrenamiento en dos bloques. El propio que usaremos para calcular el árbol de decisión y otro que usaremos para la validación del árbol.
#make this example reproducible
set.seed(123)
#usamos el 70% de los datos para crear el modelo y el 30% para su validación.
sample <- sample.split(train$Age, SplitRatio = 0.7)
trainArbol <- subset(train, sample == TRUE)
trainValida <- subset(train, sample == FALSE)
indices <- which(!sample) #guardamos indices datos validacion
Creación del árbol de decisión con los datos de entrenamiento
En R la creación de un árbol de decisión es tan simple como esto:
#crear el arbol con la particion de entrenamiento
arbol<- rpart(
formula = Survived ~ Age + Sex + Pclass,
data=trainArbol,
method='class') #class es para que calcule 0/1
Una vez tenemos creado el modelo, podemos hacer summary(arbol) pero es más sencillo visualizarlo:
fancyRpartPlot(arbol,main="SUPERVIVENCIA TITANIC",sub="")
Es interesante ver la importancia de cada variable. Del propio gráfico del árbol, vemos que la más importante va a ser el sexo, seguido a la par por las otras dos:
arbol$variable.importance
## Sex Age Pclass
## 88.08892 27.75265 21.36064
El árbol es un modelo de caja blanca, comprensible al humano, aunque inexacto. No suele usarse para grandes problemas en una versión tan básic aunque tiene la gran ventaja de que, al ser legible, puede presentarse y explicarse.
Prediccion para los datos de validación
Vamos a evaluar la precisión de nuestro árbol. La forma habitual es fijarse en métricas como sensibilidad, especificidad y nivel de precisión:
library(caret)
predictions<-predict(arbol, trainValida, type = 'class')
confusionMatrix(data=predictions, reference=trainValida$Survived)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 140 22
## 1 36 70
##
## Accuracy : 0.7836
## 95% CI : (0.7294, 0.8314)
## No Information Rate : 0.6567
## P-Value [Acc > NIR] : 3.959e-06
##
## Kappa : 0.5368
##
## Mcnemar's Test P-Value : 0.08783
##
## Sensitivity : 0.7955
## Specificity : 0.7609
## Pos Pred Value : 0.8642
## Neg Pred Value : 0.6604
## Prevalence : 0.6567
## Detection Rate : 0.5224
## Detection Prevalence : 0.6045
## Balanced Accuracy : 0.7782
##
## 'Positive' Class : 0
##
El resultado no es malo. Sería un error intentar predecir sobre los datos de entrenamiento con mayor precisión de una forma artificial (añadiendo variables que no son buenos predictores, por ejemplo).
De hecho, el problema de los árboles de decisión es que ajustan muy bien a los datos de entrenamiento pero luego, al ser tan acotados, fallan mucho en los de test (“overfitting”).
PREDICCIONES DE SUPERVIVENCIA CON LOS DATOS DE TEST
Veamos quién vive y quién muere según nuestro modelo:
predictions_test<-predict(arbol,test,type='class')
head(predictions_test,20)
## 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
## 0 0 0 0 1 0 1 0 1 0 0 0 1 0 1 1 0 0 1 0
## Levels: 0 1
Podemos calcular también la supervivencia de personas concretas, dados los parámetros sexo, edad y pclass para ellas. Por ejemplo, pongamos una edad de 40, sexo = hombre y Pclass=3. Veamos como el modelo predice lo que todos suponemos que va a pasar. Es decir, que muere:
predict(object = arbol, newdata= data.frame(Age=40, Sex=as.factor(1),Pclass=as.factor(3)))
## 0 1
## 1 0.8250653 0.1749347
La cosa está clara. Este pobre desgraciado tiene todos los números para acabar congelado en medio del mar. Sin embargo, si es una niña de primera clase la cosa cambia:
predict(object = arbol, newdata= data.frame(Age=3, Sex=as.factor(0),Pclass=as.factor(1)))
## 0 1
## 1 0.02608696 0.973913
RESUMEN
En este ejercicio, hemos creado el árbol de decisión para los datos del Titanic. Hemos tenido en cuenta sólo tres variables, para simplificar el modelo y evitar los problemas de ‘overfitting’.
El resultado ha sido un árbol de tres niveles que nos ofrece una precisión del 78%. Pero, más importante que eso, hemos visto que la variable que más información nos ofrece es el sexo, siendo casi determinante en nuestro modelo.
El modelo creado mediante un árbol de decisión, si bien es impreciso, tiene la ventaja de ofrecer una gran claridad en la selección de variables y la posibilidad de ser visualizado en forma sencilla. Este conocimiento aportado nos servirá de base para crear otros modelos más sofisticados en busca de una mayor precisión.
Finalmente, notar que siempre que se crea un modelo de machine learning es necesario separar los datos de entrenamiento en dos particiones, una para calcular el modelo y la otra para validarlo. Esto se llama método de validación hold-out, habiendo otras alternativas, la más conocida de ellas k-fold.