# DSAN 5300: Statistical Learning
source("../../dsan-globals/_globals.r")
## `nnet` in R

Since the ISLR lab within the Deep Learning chapter uses R’s `keras` library rather than `nnet`[^1], here are some quick examples of how `nnet` works that may help you get started on the R portion of Lab 9.
Here we’ll load the `Advertising.csv` dataset used in the beginning of ISLR:

```r
library(tidyverse) |> suppressPackageStartupMessages()
ad_df <- read_csv("https://www.statlearning.com/s/Advertising.csv", show_col_types = FALSE)
```
```
New names:
• `` -> `...1`
```
```r
colnames(ad_df) <- c("id", colnames(ad_df)[2:5])
ad_df |> head()
```
| id | TV | radio | newspaper | sales |
|---|---|---|---|---|
| 1 | 230.1 | 37.8 | 69.2 | 22.1 |
| 2 | 44.5 | 39.3 | 45.1 | 10.4 |
| 3 | 17.2 | 45.9 | 69.3 | 9.3 |
| 4 | 151.5 | 41.3 | 58.5 | 18.5 |
| 5 | 180.8 | 10.8 | 58.4 | 12.9 |
| 6 | 8.7 | 48.9 | 75.0 | 7.2 |
A scatterplot of `TV` vs. `sales` looks as follows:

```r
ad_df |> ggplot(aes(x = TV, y = sales)) +
  geom_point() +
  theme_dsan()
```
Here we use `lm()`, also used near the beginning of ISLR, to obtain OLS estimates of the coefficients relating `TV`, `radio`, and `newspaper` to `sales`:

```r
reg_model <- lm(
  sales ~ TV + radio + newspaper,
  data=ad_df
)
print(summary(reg_model))
```
```
Call:
lm(formula = sales ~ TV + radio + newspaper, data = ad_df)

Residuals:
    Min      1Q  Median      3Q     Max 
-8.8277 -0.8908  0.2418  1.1893  2.8292 

Coefficients:
             Estimate Std. Error t value Pr(>|t|)    
(Intercept)  2.938889   0.311908   9.422   <2e-16 ***
TV           0.045765   0.001395  32.809   <2e-16 ***
radio        0.188530   0.008611  21.893   <2e-16 ***
newspaper   -0.001037   0.005871  -0.177     0.86    
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 1.686 on 196 degrees of freedom
Multiple R-squared:  0.8972,	Adjusted R-squared:  0.8956 
F-statistic: 570.3 on 3 and 196 DF,  p-value: < 2.2e-16
```
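In other words, the fitted OLS model (rounding to three decimal places) is \(\widehat{\text{sales}} = 2.939 + 0.046 \cdot \text{TV} + 0.189 \cdot \text{radio} - 0.001 \cdot \text{newspaper}\); this is the equation we’ll evaluate “by hand” in the prediction code below.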
While we can’t really “fully” visualize the model in 2D or even 3D (since there are 3 features and 1 label, which would require a 4D visualization), we can still obtain a helpful 2D visualization that broadly resembles the above visualization of `TV` vs. `sales`.

To achieve this, we freeze two of the feature values (`radio` and `newspaper`) at their means and then plot what our model says about the relation between `TV` and `sales` at these held-constant `radio` and `newspaper` values:
# "Freeze" radio and newspaper values at their means
<- mean(ad_df$radio)
radio_mean <- mean(ad_df$newspaper)
news_mean # Define the range of TV values over which we want to plot predictions
<- seq(0, 300, 10)
TV_vals # Extract all coefficients from our model
<- reg_model$coef
reg_coefs # For every value v in TV_vals, compute prediction
# yhat(v, radio_mean, news_mean)
<- function(TV_val) {
get_prediction <- reg_coefs['(Intercept)']
intercept <- reg_coefs['TV'] * TV_val
TV_term <- reg_coefs['radio'] * radio_mean
radio_term <- reg_coefs['newspaper'] * news_mean
news_term return(intercept + TV_term + radio_term + news_term)
}# Compute predictions for each value of TV_vals
<- tibble(TV=TV_vals) |> mutate(
pred_df sales_pred = get_prediction(TV)
)ggplot() +
geom_point(data=ad_df, aes(x=TV, y=sales)) +
geom_line(
data=pred_df, aes(x=TV, y=sales_pred),
linewidth=1, color=cb_palette[2]
+
) theme_dsan()
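(Note that, rather than computing these predictions “by hand” from the coefficients, we could equivalently use R’s `predict()` function here, just as we’ll do for the NN model below. A quick sketch, where `lm_input_df` and `lm_preds` are just illustrative names, not part of the lab code:)

```r
# Same held-at-their-means approach, but letting predict() do the arithmetic
# (lm_input_df / lm_preds are illustrative names only)
lm_input_df <- data.frame(TV=TV_vals, radio=radio_mean, newspaper=news_mean)
lm_preds <- predict(reg_model, newdata=lm_input_df)
```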
## `nnet` for (Simple) NN Model Weights

Here, the reason I put “(Simple)” is because, for example, `nnet` only supports networks with either (a) no hidden layers at all, or (b) a single hidden layer.

To show you how to fit NN models using `nnet` (without giving away the full code required for this part of the lab), I am using just the default parameter settings for the `nnet()` function; on the Lab itself you’ll need to read the instructions more carefully and think about how to modify this code to achieve the desired result.
```r
library(nnet)
nn_model <- nnet(
  sales ~ TV + radio + newspaper,
  size=10,
  linout=TRUE,
  data=ad_df
)
```
```
# weights:  51
initial  value 52680.989030 
iter  10 value 4003.592232
iter  20 value 3556.990620
iter  30 value 3277.830385
iter  40 value 2693.332152
iter  50 value 1442.629937
iter  60 value 713.350821
iter  70 value 419.466340
iter  80 value 314.564890
iter  90 value 202.253371
iter 100 value 84.217076
final  value 84.217076 
stopped after 100 iterations
```
```r
nn_model
```

```
a 3-10-1 network with 51 weights
inputs: TV radio newspaper 
output(s): sales 
options were - linear output units 
```
From the second part of the output (the output from just the line `nn_model`), you should think through why it’s called a “3-10-1 network”, and then why this architecture would require estimating 51 weights.
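(If you’d like to sanity-check your count without being given the reasoning: the fitted `nnet` object stores its estimated weights in `nn_model$wts`, so for example:)

```r
# Total number of estimated weights in the fitted network (should be 51)
length(nn_model$wts)
# summary() prints the individual weights, grouped by connection
summary(nn_model)
```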
To visualize what’s happening, we can take the same approach we took in the previous visualization: see what our NN predicts for `sales` across a range of `TV` values, with `radio` and `newspaper` held constant at their means.

First, note that R’s `predict()` function takes in (1) a fitted model and (2) a `data.frame` where each row is a vector of values you want to generate a prediction for. So, for example, we can obtain a single prediction for a specific set of `TV`, `radio`, and `newspaper` values like:
```r
predict(nn_model, data.frame(TV=10, radio=23, newspaper=30))
```

```
      [,1]
1 6.698056
```
So, for ease-of-use with this `predict()` functionality, we first construct a `tibble` where each row represents a tuple `(TV_val, radio_mean, news_mean)`:

```r
nn_input_df <- data.frame(TV=TV_vals, radio=radio_mean, newspaper=news_mean)
as.data.frame(nn_input_df)
```
And now, by plugging this `tibble` into `predict()`, we obtain our NN’s prediction for the inputs in each row:

```r
nn_pred_df <- nn_input_df
nn_pred_df$sales_pred <- predict(nn_model, nn_input_df)
as.data.frame(nn_pred_df)
```
Which we can visualize using the same approach we used for the linear model above (the non-linearity is subtle, but we can see the line varying in a way that a straight line \(y = mx + b\) would not!):

```r
ggplot() +
  geom_point(data=ad_df, aes(x=TV, y=sales)) +
  geom_line(
    data=nn_pred_df, aes(x=TV, y=sales_pred),
    linewidth=1, color=cb_palette[2]
  ) +
  theme_dsan()
```
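(If the non-linearity is hard to see, one way to make it more apparent is to overlay the linear model’s prediction line from earlier on the same plot. A quick sketch, reusing the `pred_df` and `nn_pred_df` objects from above, and assuming `cb_palette[1]` gives a color that contrasts with `cb_palette[2]`:)

```r
# Overlay the OLS prediction line and the NN prediction line on one plot,
# so the (subtle) curvature of the NN's predictions is easier to see
ggplot() +
  geom_point(data=ad_df, aes(x=TV, y=sales)) +
  geom_line(
    data=pred_df, aes(x=TV, y=sales_pred),
    linewidth=1, color=cb_palette[1]
  ) +
  geom_line(
    data=nn_pred_df, aes(x=TV, y=sales_pred),
    linewidth=1, color=cb_palette[2]
  ) +
  theme_dsan()
```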
[^1]: Keras is a more complex, heavy-duty neural network library, but for the purposes of the lab (showing how models like logistic regression can be “reconceptualized” as simple neural networks) the simpler `nnet` library has a less-steep learning curve!