From f54cb60a520f69edaadf83a436ff05a8ff814c4c Mon Sep 17 00:00:00 2001 From: jjallaire Date: Tue, 31 Jul 2018 11:29:47 -0400 Subject: [PATCH] narrow text a bit --- vignettes/tutorial_basic_regression.Rmd | 78 ++++++++++++++----------- 1 file changed, 44 insertions(+), 34 deletions(-) diff --git a/vignettes/tutorial_basic_regression.Rmd b/vignettes/tutorial_basic_regression.Rmd index 17e7d3ef5..61719a4af 100644 --- a/vignettes/tutorial_basic_regression.Rmd +++ b/vignettes/tutorial_basic_regression.Rmd @@ -79,8 +79,8 @@ train_data[1, ] # Display sample features, notice the different scales ``` ``` -[1] 1.23247 0.00000 8.14000 0.00000 0.53800 6.14200 91.70000 3.97690 4.00000 307.00000 -[11] 21.00000 396.90000 18.72000 +[1] 1.23247 0.00000 8.14000 0.00000 0.53800 6.14200 91.70000 3.97690 +[9] 4.00000 307.00000 21.00000 396.90000 18.72000 ``` Let's add column names for better data inspection. @@ -89,7 +89,8 @@ Let's add column names for better data inspection. ```{r} library(tibble) -column_names <- c('CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT') +column_names <- c('CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', + 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT') train_df <- as_tibble(train_data) colnames(train_df) <- column_names @@ -133,11 +134,14 @@ train_labels[1:10] # Display first 10 entries It's recommended to normalize features that use different scales and ranges. Although the model might converge without feature normalization, it makes training more difficult, and it makes the resulting model more dependant on the choice of units used in the input. ```{r} -train_data <- scale(train_data) # Test data is *not* used when calculating the mean and std. +# Test data is *not* used when calculating the mean and std. + +# Normalize training data +train_data <- scale(train_data) + +# Use means and standard deviations from training set to normalize test set col_means_train <- attr(train_data, "scaled:center") col_stddevs_train <- attr(train_data, "scaled:scale") - -# use means and standard deviations from training set to normalize test set test_data <- scale(test_data, center = col_means_train, scale = col_stddevs_train) train_data[1, ] # First training sample, normalized @@ -157,17 +161,17 @@ Let's build our model. Here, we'll use a `sequential` model with two densely con build_model <- function() { model <- keras_model_sequential() %>% - layer_dense(units = 64, - activation = "relu", - input_shape = dim(train_data)[2] - ) %>% + layer_dense(units = 64, activation = "relu", + input_shape = dim(train_data)[2]) %>% layer_dense(units = 64, activation = "relu") %>% layer_dense(units = 1) - model %>% compile(loss = "mse", - optimizer = optimizer_rmsprop(), - metrics = list("mean_absolute_error") - ) + model %>% compile( + loss = "mse", + optimizer = optimizer_rmsprop(), + metrics = list("mean_absolute_error") + ) + model } @@ -176,17 +180,19 @@ model %>% summary() ``` ``` -Layer (type) Output Shape Param # -================================================================================================================ -dense_1 (Dense) (None, 64) 896 -________________________________________________________________________________________________________________ -dense_2 (Dense) (None, 64) 4160 -________________________________________________________________________________________________________________ -dense_3 (Dense) (None, 1) 65 -================================================================================================================ +_____________________________________________________________________________________ +Layer (type) Output Shape Param # +===================================================================================== +dense_5 (Dense) (None, 64) 896 +_____________________________________________________________________________________ +dense_6 (Dense) (None, 64) 4160 +_____________________________________________________________________________________ +dense_7 (Dense) (None, 1) 65 +===================================================================================== Total params: 5,121 Trainable params: 5,121 Non-trainable params: 0 +_____________________________________________________________________________________ ``` @@ -206,7 +212,7 @@ print_dot_callback <- callback_lambda( epochs <- 500 -# Store training stats +# Fit the model and store training stats history <- model %>% fit( train_data, train_labels, @@ -278,17 +284,21 @@ test_predictions[ , 1] ``` ``` - [1] 8.007033 18.153156 21.616066 31.284706 26.284445 19.741478 28.916748 22.962608 19.520395 21.477909 - [11] 19.306627 17.479231 14.995399 43.237667 18.181221 20.486805 28.540936 21.580551 18.925463 37.664436 - [21] 10.997092 13.498063 20.555744 14.972000 23.043034 25.181206 31.070011 33.881676 10.061424 22.055264 - [31] 19.047733 12.997090 34.065460 26.652784 17.594194 7.601262 14.989892 17.475399 19.075344 29.014477 - [41] 31.474913 29.175596 13.835101 42.377480 30.460976 26.400883 28.388220 16.290909 22.766804 23.140137 - [51] 38.747765 20.003471 11.512418 15.724887 36.897202 29.616425 12.119034 49.743866 34.149521 24.879330 - [61] 23.744190 16.407722 12.654126 18.168047 24.095661 24.611050 12.430396 24.793139 13.584899 7.653455 - [71] 37.167645 32.139378 25.921822 14.376559 27.738743 19.491718 21.113312 24.868862 37.598633 9.865499 - [81] 20.322922 39.640156 15.201268 13.107138 17.114182 20.144257 20.070990 20.365595 23.018572 31.493315 - [91] 21.859282 22.314922 27.089687 45.900990 37.716213 17.741659 37.281250 53.972012 27.489983 43.101936 -[101] 32.614670 19.385052 + [1] 9.159314 17.955666 20.573296 31.487156 25.231384 18.803967 27.139153 + [8] 21.052799 18.904579 22.056618 19.140137 17.342262 15.233129 42.001091 +[15] 19.280727 19.559774 27.276485 20.737257 19.391312 38.450863 12.431134 +[22] 16.025173 19.910103 14.362184 20.846870 24.595688 31.234753 30.109112 +[29] 11.271907 21.081585 18.724422 14.542423 33.109241 25.842684 18.071476 +[36] 9.046785 14.701134 17.113651 21.169674 27.008324 29.676132 28.280304 +[43] 15.355518 41.252007 29.731274 24.258526 25.582203 16.032135 24.014944 +[50] 22.071520 35.658638 19.342590 13.662583 15.854269 34.375328 28.051319 +[57] 13.002036 47.801872 33.513954 23.775620 25.214602 17.864346 14.284246 +[64] 17.458893 22.757492 22.424841 13.578171 22.530212 15.667303 7.438343 +[71] 38.318726 29.219141 25.282124 15.476329 24.670732 17.125381 20.079552 +[78] 23.601147 34.540359 12.151771 19.177418 37.980789 15.576267 14.904464 +[85] 17.581717 17.851192 20.480953 19.700697 21.921551 31.415789 19.116734 +[92] 21.192280 24.934101 41.778465 35.113403 19.307007 35.754066 53.983509 +[99] 26.797831 44.472233 32.520882 19.591730 ``` ## Conclusion