DSAN 5300: Statistical Learning

nnet in R

source("../../dsan-globals/_globals.r")

Since the ISLR lab within the Deep Learning chapter uses R’s keras library rather than nnet¹, here are some quick examples of how nnet works that may help you get started on the R portion of Lab 9.
Here we’ll load the Advertising.csv dataset used in the beginning of ISLR:
library(tidyverse) |> suppressPackageStartupMessages()
ad_df <- read_csv("https://www.statlearning.com/s/Advertising.csv", show_col_types = FALSE)

New names:
• `` -> `...1`
colnames(ad_df) <- c("id", colnames(ad_df)[2:5])
ad_df |> head()

| id | TV | radio | newspaper | sales |
|---|---|---|---|---|
| 1 | 230.1 | 37.8 | 69.2 | 22.1 | 
| 2 | 44.5 | 39.3 | 45.1 | 10.4 | 
| 3 | 17.2 | 45.9 | 69.3 | 9.3 | 
| 4 | 151.5 | 41.3 | 58.5 | 18.5 | 
| 5 | 180.8 | 10.8 | 58.4 | 12.9 | 
| 6 | 8.7 | 48.9 | 75.0 | 7.2 | 
A scatterplot of TV vs. sales looks as follows:
ad_df |> ggplot(aes(x = TV, y = sales)) +
  geom_point() +
  theme_dsan()
Here we use lm(), also used near the beginning of ISLR, to obtain OLS estimates of the coefficients relating TV, radio, and newspaper to sales:
reg_model <- lm(
    sales ~ TV + radio + newspaper,
    data=ad_df
)
print(summary(reg_model))
Call:
lm(formula = sales ~ TV + radio + newspaper, data = ad_df)
Residuals:
    Min      1Q  Median      3Q     Max 
-8.8277 -0.8908  0.2418  1.1893  2.8292 
Coefficients:
             Estimate Std. Error t value Pr(>|t|)    
(Intercept)  2.938889   0.311908   9.422   <2e-16 ***
TV           0.045765   0.001395  32.809   <2e-16 ***
radio        0.188530   0.008611  21.893   <2e-16 ***
newspaper   -0.001037   0.005871  -0.177     0.86    
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Residual standard error: 1.686 on 196 degrees of freedom
Multiple R-squared:  0.8972,    Adjusted R-squared:  0.8956 
F-statistic: 570.3 on 3 and 196 DF,  p-value: < 2.2e-16
While we can’t really “fully” visualize the model in 2D or even 3D (since there are 3 features and 1 label, which would require a 4D visualization), we can still obtain a helpful 2D visualization that broadly resembles the above visualization of TV vs. sales.
To achieve this, we freeze two of the feature values (radio and newspaper) at their means and then plot what our model says about the relation between TV and sales at these held-constant radio and newspaper values:
# "Freeze" radio and newspaper values at their means
radio_mean <- mean(ad_df$radio)
news_mean <- mean(ad_df$newspaper)
# Define the range of TV values over which we want to plot predictions
TV_vals <- seq(0, 300, 10)
# Extract all coefficients from our model
reg_coefs <- coef(reg_model)
# For every value v in TV_vals, compute prediction
# yhat(v, radio_mean, news_mean)
get_prediction <- function(TV_val) {
    intercept <- reg_coefs['(Intercept)']
    TV_term <- reg_coefs['TV'] * TV_val
    radio_term <- reg_coefs['radio'] * radio_mean
    news_term <- reg_coefs['newspaper'] * news_mean
    return(intercept + TV_term + radio_term + news_term)
}
# Compute predictions for each value of TV_vals
pred_df <- tibble(TV=TV_vals) |> mutate(
    sales_pred = get_prediction(TV)
)
ggplot() +
  geom_point(data=ad_df, aes(x=TV, y=sales)) +
  geom_line(
    data=pred_df, aes(x=TV, y=sales_pred),
    linewidth=1, color=cb_palette[2]
  ) +
  theme_dsan()
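As an aside, rather than multiplying the coefficients out by hand, you could get the same predictions from R’s predict() function (which we’ll also use for the NN model below) by handing it a data frame of input values. The names lm_input_df and pred_df_v2 here are just for illustration:

# Same predictions via predict(): each row pairs one TV value with the
# frozen (mean) radio and newspaper values
lm_input_df <- data.frame(TV=TV_vals, radio=radio_mean, newspaper=news_mean)
pred_df_v2 <- tibble(TV=TV_vals) |> mutate(
    sales_pred = predict(reg_model, newdata=lm_input_df)
)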
nnet for (Simple) NN Model Weights

Here, the reason I put “(Simple)” is because, for example, nnet only supports networks with either (a) no hidden layers at all, or (b) a single hidden layer.
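As a quick illustration of case (a), here is a sketch (not something the lab asks for): with size=0 plus skip=TRUE, nnet() builds skip-layer connections running directly from the inputs to the output unit, and with linout=TRUE that output is linear, so the fitted weights should roughly match the lm() coefficients from above (up to optimizer differences):

library(nnet)
# Sketch: a "network" with no hidden layer; skip=TRUE wires the inputs
# directly to the (linear) output unit
lin_nn_model <- nnet(
    sales ~ TV + radio + newspaper,
    size=0, skip=TRUE, linout=TRUE,
    data=ad_df
)
summary(lin_nn_model)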
Here, to show you how to fit NN models using nnet (without giving away the full code required for this part of the lab), I am using mostly default parameter settings for the nnet() function (size is required, and linout=TRUE requests a linear output unit for regression). On the Lab itself you’ll need to read the instructions more carefully and think about how to modify this code to achieve the desired result.
library(nnet)
nn_model <- nnet(
    sales ~ TV + radio + newspaper,
    size=10,
    linout=TRUE,
    data=ad_df
)

# weights:  51
initial  value 52680.989030 
iter  10 value 4003.592232
iter  20 value 3556.990620
iter  30 value 3277.830385
iter  40 value 2693.332152
iter  50 value 1442.629937
iter  60 value 713.350821
iter  70 value 419.466340
iter  80 value 314.564890
iter  90 value 202.253371
iter 100 value 84.217076
final  value 84.217076 
stopped after 100 iterations
nn_model

a 3-10-1 network with 51 weights
inputs: TV radio newspaper 
output(s): sales 
options were - linear output units 
From the second part of the output (the output from just the line nn_model), you should think through why it’s called a “3-10-1 network”, and then why this architecture would require estimating 51 weights.
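Once you’ve worked through the arithmetic, you can check it against the fitted object itself: nnet stores the full weight vector in the wts component of the model, so its length should equal the number of weights you counted. (For the count itself, note that every hidden unit and the output unit each receive one weight per incoming connection plus a bias term.)

# Sanity check: the length of the stored weight vector
length(nn_model$wts)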
To visualize what’s happening, we can take the same approach we took in the previous visualization: see what our NN predicts for sales across a range of TV values, with radio and newspaper held constant at their means.
First, note that R’s predict() function takes in (1) a fitted model and (2) a data.frame where each row is a vector of values you want to generate a prediction for. So, for example, we can obtain a single prediction for a specific set of TV, radio, and newspaper values like:
predict(nn_model, data.frame(TV=10, radio=23, newspaper=30))

      [,1]
1 6.698056
So, for ease of use with this predict() functionality, we first construct a data frame where each row represents a tuple (TV_val, radio_mean, news_mean):
nn_input_df <- data.frame(TV=TV_vals, radio=radio_mean, newspaper=news_mean)
as.data.frame(nn_input_df)

And now, by plugging this data frame into predict(), we obtain our NN’s prediction for the inputs in each row:
nn_pred_df <- nn_input_df
nn_pred_df$sales_pred <- predict(nn_model, nn_input_df)
as.data.frame(nn_pred_df)

Which we can visualize using the same approach we used for the linear model above (the non-linearity is subtle, but we can see the line varying in a way that a straight line \(y = mx + b\) would not!)
ggplot() +
  geom_point(data=ad_df, aes(x=TV, y=sales)) +
  geom_line(
    data=nn_pred_df, aes(x=TV, y=sales_pred),
    linewidth=1, color=cb_palette[2]
  ) +
  theme_dsan()
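One caveat worth flagging as an aside (the seed value below is arbitrary, my own choice for illustration): nnet() draws its initial weights at random, so re-running the fit can converge to a different set of weights and hence a slightly different curve. For a reproducible fit, set R’s random seed before calling nnet():

# nnet() initializes its weights randomly; fixing the seed makes the
# fit reproducible from run to run
set.seed(5300)
nn_model <- nnet(
    sales ~ TV + radio + newspaper,
    size=10,
    linout=TRUE,
    data=ad_df
)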
¹ Keras is a more complex, heavy-duty neural network library, but for the purposes of the lab (showing how models like logistic regression can be “reconceptualized” as simple neural networks) the simpler nnet library has a less-steep learning curve!
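To give a flavor of the comparison in that footnote, here is a rough keras analogue of the same 3-10-1 architecture. This is a sketch only, assuming the keras R package is installed and configured with a backend; the optimizer, loss, and epoch count are illustrative choices, not taken from the lab:

library(keras)
# Inputs as a numeric matrix, as keras expects
X <- as.matrix(ad_df[, c("TV", "radio", "newspaper")])
# "sigmoid" hidden units match nnet's logistic units, and the final
# layer's default linear activation matches linout=TRUE
keras_nn <- keras_model_sequential() |>
  layer_dense(units = 10, activation = "sigmoid", input_shape = 3) |>
  layer_dense(units = 1)
keras_nn |> compile(optimizer = "rmsprop", loss = "mse")
keras_nn |> fit(X, ad_df$sales, epochs = 100, verbose = 0)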