---
pagetitle: "A05"
output: html_document
---
```{css, echo=FALSE}
body {
font-family: Helvetica, sans-serif;
}
```
<body>
<img src="https://mem.masters.upc.edu/++genweb++static/images/logoUPC.png" />
<h1>IBM Attrition</h1>
<h2>Machine Learning Assignment</h2>
<h3>Date: 20-Apr-2021 <br />
By: Babak Barghi, Alexander Rutten, Will Rains
</h3>
<p>
<br />
</p>
<p><span STYLE="font-weight:bold">Assignment Instructions:</span>
<br />
This is a fictional people analytics data set created by IBM data scientists:
https://www.kaggle.com/pavansubhasht/ibm-hr-analytics-attrition-dataset
Your job is to build a model to predict if an employee will leave the company. This is defined by the categorical variable Attrition. The employees leaving the company have a value of Attrition equal to Yes. You can consider several classification models like decision trees, logistic regression or random forests.
You can see how you can implement classification models at:
https://parsnip.tidymodels.org/reference/
Select a good model performing cross validation and tuning hyperparameters on the training set. Once you have defined your model, test its performance on the test set. You must deliver a GroupX.Rmd file, where X is the name of your group. This file will include your analysis, reading the .csv file in the same directory as the file.
</p>
<h3>Introduction</h3>
In the following report, we use data from *IBM HR Analytics Employee Attrition & Performance*, which includes a total of 35 variables. The goal of this task is to predict Attrition by selecting a set of explanatory variables and building a random forest classifier.
```{r setup, message=FALSE, warning=FALSE}
library(tidyverse)
library(tidymodels)
library(parsnip)
library(ggthemes)
library(skimr)
library(workflows)
library(yardstick)
library(kableExtra)
library(corrplot)
library(RColorBrewer)
knitr::opts_chunk$set(echo = TRUE)
```
<h3>Examine Data</h3>
<h4>Data Importing, Cleaning, Exploration</h4>
```{r , message=FALSE, warning=FALSE}
#Import data
atr_data_raw <- read_csv("WA_Fn-UseC_-HR-Employee-Attrition.csv")
skim(atr_data_raw)
summary(atr_data_raw)
```
This dataset contains 1470 observations with 35 variables to consider.
Using the *summary* and *skim* functions, we get a broad view of the data frame. It is clear that some of the variables are not useful:

* EmployeeCount is equal to 1 for all observations, so it cannot add value to the analysis.
* Over18 is equal to 'Y' for every employee, which means nobody is under 18 years old.
* StandardHours is equal to 80 for all rows.
* EmployeeNumber only identifies a specific person, which is not helpful for the purposes of this analysis.
```{r}
atr_data <- atr_data_raw
# remove useless variables
atr_data <- atr_data %>% select(-EmployeeCount, -Over18, -StandardHours, -EmployeeNumber)
```
After removing the unnecessary variables and confirming there are no NA values, it is time to check the variable types.
All character variables are converted to factors so that they are treated as categorical, and the target variable Attrition is moved to the first column.
```{r}
atr_data <- atr_data %>%
mutate_if(is.character, as.factor) %>%
select(Attrition, everything())
```
Some other features are also categorical even though they are stored as integers, so they are converted to factors as well.
```{r}
atr_data <- atr_data %>%
  mutate(across(c(Education, JobSatisfaction, NumCompaniesWorked,
                  PerformanceRating, RelationshipSatisfaction,
                  StockOptionLevel, WorkLifeBalance,
                  EnvironmentSatisfaction, JobInvolvement, JobLevel),
                as.factor))
```
<p>
<span STYLE="font-weight:bold">Steps taken to import/clean data:</span>
<ol>
<li>Imported data from the file ("WA_Fn-UseC_-HR-Employee-Attrition.csv") in the working directory</li>
<li>Changed variables from text or numeric to factor where they are better treated as categorical</li>
<li>Removed uninformative variables (constant columns and the employee identifier) to reduce the dataset size, since R functions can count records instead</li>
</ol>
</p>
<h4>Data Visualizations</h4>
In this part we explore the dataset with respect to the target variable to gain a better understanding.
First, the overall attrition percentage:
```{r, fig.align="center", fig.pos="H"}
atr_data %>%
group_by(Attrition) %>%
summarise(Average = n() / nrow(atr_data)*100) %>%
ggplot(aes(Attrition, Average, fill = Attrition)) +
geom_col() +
ylab("Percentage") +
geom_text(aes(Attrition, Average + 3, label = round(Average, 2))) +
scale_fill_brewer(palette = "RdYlBu") +
theme_minimal()
```
Gender differences in Attrition.
```{r, fig.align="center", fig.pos="H"}
atr_data %>%
group_by(Gender, Attrition) %>%
count(Attrition) %>%
ggplot(aes(x = Gender, y = n, fill = Attrition)) +
geom_col(position = "stack") +
labs(
x = "Gender",
y = "No. of Employees") +
scale_fill_brewer(palette = "RdYlBu") +
theme_minimal()
```
Attrition by Department and average Salary.
```{r, fig.align="center", fig.pos="H", message=FALSE}
atr_data %>%
group_by(Attrition, Department) %>%
summarise(average_salary = mean(MonthlyIncome)) %>%
ggplot(aes(Department, average_salary, fill = Attrition)) +
geom_col(position = "dodge") +
ylab("Average Salary") +
scale_fill_brewer(palette = "RdYlBu") +
theme_minimal() +
theme(axis.title.x = element_blank(),
      axis.ticks.x = element_blank())  # applied after theme_minimal() so it is not overridden
```
Distance from home & Job Role by Attrition.
```{r, fig.height = 8, fig.width = 13, fig.align="center", fig.pos="H", message=FALSE}
atr_data %>%
ggplot(aes(x = JobRole, y = DistanceFromHome, color = Attrition)) +
geom_boxplot() +
geom_jitter(width=0.22, alpha=0.2) +
scale_x_discrete(guide = guide_axis(n.dodge=2)) +
labs(
y = "Distance From Home",
x = "Job Role") +
theme_minimal()
```
<section id = "corrplot">
<p>
<br />
Now we take a look at a correlation plot to get an initial indication of correlation within the dataset. For this we make a few adjustments, including recoding Attrition as a numeric variable, since it is the variable whose correlations we most want to understand.
</p>
```{r fig.height = 9, fig.width = 15, warning=FALSE, message=FALSE}
# corrplot and RColorBrewer are already loaded in the setup chunk
atr_data_numonly <- atr_data_raw %>%
  mutate(Attrition = case_when(Attrition == 'Yes' ~ 1,
                               Attrition == 'No' ~ 0)) %>%
  select_if(is.numeric)
atr_data_numonly[is.na(atr_data_numonly)] <- 0
acorr <- cor(atr_data_numonly)
corrplot(acorr, type="upper", na.label = "N",
col=brewer.pal(n=8, name="RdYlBu"))
```
</section>
The correlation plot shows that no variable has a particularly strong positive or negative correlation with attrition, suggesting that a more flexible algorithm, such as a machine learning model, is needed to predict attrition from this dataset.
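To make this concrete, a quick sketch (reusing the `acorr` matrix computed above) lists the numeric variables with the largest absolute correlation with Attrition:
```{r}
# Correlation of each numeric variable with Attrition, ordered by magnitude
atr_corr <- acorr[, "Attrition"]
atr_corr <- atr_corr[names(atr_corr) != "Attrition"]  # drop the self-correlation
head(atr_corr[order(abs(atr_corr), decreasing = TRUE)], 10)
```
Even the strongest entries are modest in magnitude, consistent with the plot above.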
<h3>Pre-process/Split Data</h3>
The only pre-processing done before splitting is changing certain fields to factors, which is needed for modeling, and removing some useless fields that would slow down run times and obscure results if kept in the data frame.
After these pre-processing steps, we define the train and test data frames:
```{r}
set.seed(42) #split with 0.7 ratio
atr_split <- initial_split(atr_data, prop = 0.7)
training(atr_split) %>% glimpse() #train set
testing(atr_split) %>% glimpse() #test set
```
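Since `initial_split()` samples at random rather than by stratum, a quick sketch (using the objects just defined) checks whether the Attrition class balance is similar in both partitions:
```{r}
# Compare Attrition proportions between the training and test partitions
bind_rows(
  training(atr_split) %>% count(Attrition) %>% mutate(set = "train"),
  testing(atr_split) %>% count(Attrition) %>% mutate(set = "test")
) %>%
  group_by(set) %>%
  mutate(prop = n / sum(n)) %>%
  ungroup()
```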
The next step is to use the **recipes** package for data pre-processing.
```{r}
atr_recipe <- training(atr_split) %>%
recipe(Attrition ~.) %>%
step_corr(all_numeric()) %>%
step_center(all_numeric(), -all_outcomes()) %>%
step_scale(all_numeric(), -all_outcomes()) %>%
prep()
atr_recipe
```
The recipe has not removed any variables, meaning no numeric predictors were correlated enough for *step_corr* to drop them.
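To verify this, one can inspect the prepared recipe directly; this is a quick check using the `tidy()` method from the **recipes** package:
```{r}
# List the preprocessing steps; step 1 (step_corr) reports any columns it removed
tidy(atr_recipe)
tidy(atr_recipe, number = 1)
```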
<h3>Train</h3>
<h4>Define Model, cross validation, tuning hyperparameters</h4>
<p>
A random forest model is trained using the *ranger* engine.
The <span STYLE="font-weight:bold">random forest</span> algorithm was selected for its ability to build robust classification models while avoiding the over-fitting that single decision trees tend to suffer from.
</p>
```{r}
atr_ranger <- rand_forest(trees = 1000, mode = "classification") %>%
set_engine("ranger")
```
Fitting the model:
Using the *workflows* package, we combine the recipe and the model into a workflow and fit it on the training set.
```{r}
atr_ranger_wf <- workflow() %>%
add_recipe(atr_recipe) %>%
add_model(atr_ranger) %>%
fit(training(atr_split))
```
Predicting on the training set:
```{r}
atr_pred_train <- atr_ranger_wf %>%
predict(training(atr_split)) %>%
bind_cols(training(atr_split))
atr_pred_train %>% glimpse()
```
Now that we have fitted the model, it is time to evaluate its performance on the training set using a confusion matrix.
```{r}
atr_pred_train %>%
conf_mat(truth = Attrition, estimate = .pred_class)
```
To evaluate the *sensitivity* and the *specificity*, the following formulas are used:
$$ Sensitivity = \frac{TP}{TP+FN} $$
$$ Specificity = \frac{TN}{TN+FP} $$
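As a worked toy example (made-up counts, not our model's output): with TP = 40, FN = 10, TN = 900 and FP = 50, sensitivity is 40/(40+10) = 0.80 and specificity is 900/(900+50) ≈ 0.95. Note also that **yardstick** treats the first factor level as the event by default, so with levels (No, Yes) the "positive" class here is No; specificity therefore reflects how well the leavers (Attrition = Yes) are identified.
```{r}
# Toy example (made-up counts): sensitivity and specificity by hand
TP <- 40; FN <- 10; TN <- 900; FP <- 50
c(sensitivity = TP / (TP + FN), specificity = TN / (TN + FP))
```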
Now we calculate the performance metrics of the model using **metric_set**.
```{r}
class_metrics <- metric_set(accuracy, precision, recall, sensitivity, specificity)
atr_pred_train %>%
class_metrics(truth = Attrition, estimate = .pred_class) %>%
kbl(caption = "Random Forest Model Performance - Train") %>%
kable_classic(full_width = F)
```
The results on the training set look strong: all metrics except specificity are very high. Further validation on the test set is needed to make sure the model has not overfitted.
<h3>Validate</h3>
Confusion matrix (test set):
We can check for overfitting by calculating performance metrics on the test set:
```{r}
atr_ranger_wf %>%
predict(testing(atr_split)) %>%
bind_cols(testing(atr_split)) %>%
conf_mat(truth = Attrition, estimate = .pred_class)
```
The metrics for the test set are shown below:
```{r}
atr_ranger_wf %>%
predict(testing(atr_split)) %>%
bind_cols(testing(atr_split)) %>%
class_metrics(truth = Attrition, estimate = .pred_class) %>%
kbl(caption = "Random Forest Model Performance - Test") %>%
kable_classic(full_width = F)
```
When validated on the test set, the model performs worse than on the training set, particularly on the specificity metric.
Below we continue validation with cross-validation.
Let’s define a workflow with the atr_recipe and atr_ranger we have previously defined:
```{r}
atr_ranger_wf <- workflow() %>%
add_recipe(atr_recipe) %>%
add_model(atr_ranger)
#and define a cross-validation with five folds on the training set:
atr_folds <- vfold_cv(training(atr_split), v = 5)
# and perform cross validation:
fit_resamples(atr_ranger_wf, atr_folds) %>%
collect_metrics()
```
<h3>Hyperparameter Tuning</h3>
```{r}
atr_grid <- expand.grid(mtry = c(1,5,10), trees = c(500, 1000, 1500))
atr_grid
atr_ranger <- rand_forest(mode = "classification", mtry = tune(), trees = tune()) %>%
set_engine("ranger")
atr_tune <- tune_grid(object = atr_ranger,
preprocessor = atr_recipe,
resamples = atr_folds,
grid = atr_grid,
metrics = metric_set(accuracy, roc_auc))
show_best(atr_tune, metric = "accuracy")
```
Here we have run hyperparameter tuning to see whether we chose the optimal settings for our random forest model. The results suggest that we chose well, with 1000 trees and the default *mtry* value.
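As a follow-up sketch (using the objects defined above), the best combination found by `tune_grid()` could be plugged back into the workflow and refit on the full training set:
```{r}
# Finalize the tuned model with the best hyperparameters and refit on the training set
best_params <- select_best(atr_tune, metric = "accuracy")
atr_final_wf <- workflow() %>%
  add_recipe(atr_recipe) %>%
  add_model(atr_ranger) %>%
  finalize_workflow(best_params) %>%
  fit(training(atr_split))
```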
<h3>Conclusions</h3>
Attrition is a very important metric for businesses to monitor. Retraining is expensive and time-consuming, and replacing people can be very difficult, particularly when they hold a lot of company-specific business process knowledge.
If an HR department could predict attrition, it could take steps to engage the person at risk and reduce the likelihood of them leaving, or identify the causes that might make someone want to leave the company. Even before hiring an employee, if HR could predict whether a candidate will leave soon or stay for a long time, that would be an efficient way to control cost and reduce risk.
We have done just that above: built a model using the random forest algorithm to predict whether someone will leave the company. This methodology and analysis could be of great use to an HR department and is another example of why these techniques can be revolutionary in business.
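As a final illustrative sketch (hypothetical usage, building on the finalized workflow `atr_final_wf` from the tuning section), such a model could score individual employees with attrition probabilities:
```{r}
# Hypothetical scoring example: class probabilities for a single employee
new_employee <- testing(atr_split) %>% slice(1)
predict(atr_final_wf, new_employee, type = "prob") %>%
  bind_cols(predict(atr_final_wf, new_employee))
```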
</body>